Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-04-19 19:03:51 +00:00

Commit ae43f09ef7 (parent cf3974c829)

All the unet weights should now be initialized with the right dtype.
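Note (not part of the diff): the commit relies on PyTorch's factory keywords. Passing dtype when a layer is constructed allocates its parameters in that dtype directly, instead of building float32 weights and casting afterwards. A minimal sketch with illustrative sizes:

import torch
import torch.nn as nn

# Constructing with an explicit dtype creates the weights in that precision right away.
proj = nn.Linear(320, 640, dtype=torch.float16)
print(proj.weight.dtype)  # torch.float16

# The default path allocates float32 first and needs a separate cast.
proj_fp32 = nn.Linear(320, 640)
print(proj_fp32.weight.dtype)          # torch.float32
print(proj_fp32.half().weight.dtype)   # torch.float16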
@@ -51,9 +51,9 @@ def init_(tensor):

 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
+    def __init__(self, dim_in, dim_out, dtype=None):
         super().__init__()
-        self.proj = comfy.ops.Linear(dim_in, dim_out * 2)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
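Note (not part of the diff): with dtype threaded through, GEGLU roughly takes the shape below; plain nn.Linear stands in for comfy.ops.Linear, and the tail of forward() is taken from the upstream class rather than this hunk:

import torch.nn as nn
import torch.nn.functional as F

class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out, dtype=None):
        super().__init__()
        # dtype is forwarded so the projection weights start in the target precision
        self.proj = nn.Linear(dim_in, dim_out * 2, dtype=dtype)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)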
@@ -68,7 +68,7 @@ class FeedForward(nn.Module):
         project_in = nn.Sequential(
             comfy.ops.Linear(dim, inner_dim, dtype=dtype),
             nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim)
+        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype)

         self.net = nn.Sequential(
             project_in,
@@ -89,8 +89,8 @@ def zero_module(module):
     return module


-def Normalize(in_channels):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+def Normalize(in_channels, dtype=None):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)


 class SpatialSelfAttention(nn.Module):
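Note (not part of the diff): torch.nn.GroupNorm accepts dtype as a factory keyword, so forwarding it from Normalize() puts the affine parameters in the requested precision, while dtype=None keeps the float32 default and leaves existing callers unaffected. A quick check with illustrative channel counts:

import torch

norm = torch.nn.GroupNorm(num_groups=32, num_channels=320, eps=1e-6, affine=True, dtype=torch.float16)
print(norm.weight.dtype, norm.bias.dtype)   # torch.float16 torch.float16

norm_default = torch.nn.GroupNorm(num_groups=32, num_channels=320, eps=1e-6, affine=True, dtype=None)
print(norm_default.weight.dtype)            # torch.float32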
@@ -594,7 +594,7 @@ class SpatialTransformer(nn.Module):
             context_dim = [context_dim]
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels)
+        self.norm = Normalize(in_channels, dtype=dtype)
         if not use_linear:
             self.proj_in = nn.Conv2d(in_channels,
                                      inner_dim,
@@ -111,14 +111,14 @@ class Upsample(nn.Module):
                  upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)

     def forward(self, x, output_shape=None):
         assert x.shape[1] == self.channels
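Note (not part of the diff): passing dtype=dtype to conv_nd works because the helper forwards its keyword arguments to the underlying Conv layer, which takes the same factory keywords. Roughly, as a sketch of that existing helper rather than a verbatim copy:

import torch.nn as nn

def conv_nd(dims, *args, **kwargs):
    # Dispatch to the 1D/2D/3D convolution; **kwargs carries padding, stride, dtype, ...
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")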
@@ -160,7 +160,7 @@ class Downsample(nn.Module):
                  downsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -169,7 +169,7 @@ class Downsample(nn.Module):
         stride = 2 if dims != 3 else (1, 2, 2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
             )
         else:
             assert self.channels == self.out_channels
@@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm

         self.in_layers = nn.Sequential(
-            normalization(channels),
+            normalization(channels, dtype=dtype),
             nn.SiLU(),
             conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
         )
@@ -228,11 +228,11 @@ class ResBlock(TimestepBlock):
         self.updown = up or down

         if up:
-            self.h_upd = Upsample(channels, False, dims)
-            self.x_upd = Upsample(channels, False, dims)
+            self.h_upd = Upsample(channels, False, dims, dtype=dtype)
+            self.x_upd = Upsample(channels, False, dims, dtype=dtype)
         elif down:
-            self.h_upd = Downsample(channels, False, dims)
-            self.x_upd = Downsample(channels, False, dims)
+            self.h_upd = Downsample(channels, False, dims, dtype=dtype)
+            self.x_upd = Downsample(channels, False, dims, dtype=dtype)
         else:
             self.h_upd = self.x_upd = nn.Identity()

@@ -240,11 +240,11 @@ class ResBlock(TimestepBlock):
             nn.SiLU(),
             linear(
                 emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
             ),
         )
         self.out_layers = nn.Sequential(
-            normalization(self.out_channels),
+            normalization(self.out_channels, dtype=dtype),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
@@ -604,6 +604,7 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        dtype=self.dtype
                     )
                 ]
                 ch = mult * model_channels
@@ -651,10 +652,11 @@ class UNetModel(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
+                            dtype=self.dtype
                         )
                         if resblock_updown
                         else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
                         )
                     )
                 )
@@ -679,6 +681,7 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
+                dtype=self.dtype
             ),
             AttentionBlock(
                 ch,
@@ -698,6 +701,7 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
+                dtype=self.dtype
             ),
         )
         self._feature_size += ch
@@ -715,6 +719,7 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        dtype=self.dtype
                     )
                 ]
                 ch = model_channels * mult
@@ -758,18 +763,19 @@ class UNetModel(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             up=True,
+                            dtype=self.dtype
                         )
                         if resblock_updown
-                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
                     )
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch

         self.out = nn.Sequential(
-            normalization(ch),
+            normalization(ch, dtype=self.dtype),
             nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
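Note (not part of the diff): with dtype passed into every ResBlock, Upsample, Downsample, normalization and conv_nd call above, the UNet's parameters should come out of the constructor already in the working dtype, with no follow-up .half() pass. A hypothetical sanity check (helper name and usage are illustrative, not part of the codebase):

import torch

def check_param_dtypes(model, expected_dtype):
    # Collect any parameters that were not created in the expected precision.
    wrong = {name: p.dtype for name, p in model.named_parameters() if p.dtype != expected_dtype}
    assert not wrong, f"parameters not initialized as {expected_dtype}: {wrong}"

# e.g. check_param_dtypes(unet, torch.float16) after building a fp16 model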
@@ -206,13 +206,13 @@ def mean_flat(tensor):
     return tensor.mean(dim=list(range(1, len(tensor.shape))))


-def normalization(channels):
+def normalization(channels, dtype=None):
     """
     Make a standard normalization layer.
     :param channels: number of input channels.
     :return: an nn.Module for normalization.
     """
-    return GroupNorm32(32, channels)
+    return GroupNorm32(32, channels, dtype=dtype)


 # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
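Note (not part of the diff): GroupNorm32 is assumed here to subclass nn.GroupNorm (as in the upstream LDM code), so the new dtype keyword simply flows through to the parent constructor and lands on the affine parameters. A stripped-down illustration that ignores the class's forward() override:

import torch
import torch.nn as nn

class GroupNorm32(nn.GroupNorm):
    # Stand-in: the real class also overrides forward(); only construction matters here.
    pass

norm = GroupNorm32(32, 320, dtype=torch.float16)
print(norm.weight.dtype, norm.bias.dtype)  # torch.float16 torch.float16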