This isn't needed for inference.

This commit is contained in:
comfyanonymous 2023-06-14 13:05:08 -04:00
parent fa2cca056c
commit 9d54066ebc

View File

@ -610,13 +610,12 @@ class SpatialTransformer(nn.Module):
for d in range(depth)] for d in range(depth)]
) )
if not use_linear: if not use_linear:
self.proj_out = zero_module(nn.Conv2d(inner_dim, self.proj_out = nn.Conv2d(inner_dim,in_channels,
in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0)) padding=0)
else: else:
self.proj_out = zero_module(comfy.ops.Linear(in_channels, inner_dim)) self.proj_out = comfy.ops.Linear(in_channels, inner_dim)
self.use_linear = use_linear self.use_linear = use_linear
def forward(self, x, context=None, transformer_options={}): def forward(self, x, context=None, transformer_options={}):