Make unet work with any input shape.

This commit is contained in:
comfyanonymous 2023-05-02 13:31:43 -04:00
parent a307c3f12c
commit 66c8aa5c3e

View File

@ -76,12 +76,14 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
support it as an extra input.
"""
def forward(self, x, emb, context=None, transformer_options={}):
def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
@ -105,14 +107,21 @@ class Upsample(nn.Module):
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x):
def forward(self, x, output_shape=None):
print("upsample", output_shape)
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
if output_shape is not None:
shape[1] = output_shape[3]
shape[2] = output_shape[4]
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
shape = [x.shape[2] * 2, x.shape[3] * 2]
if output_shape is not None:
shape[0] = output_shape[2]
shape[1] = output_shape[3]
x = F.interpolate(x, size=shape, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
@ -813,9 +822,14 @@ class UNetModel(nn.Module):
ctrl = control['output'].pop()
if ctrl is not None:
hsp += ctrl
h = th.cat([h, hsp], dim=1)
del hsp
h = module(h, emb, context, transformer_options)
if len(hs) > 0:
output_shape = hs[-1].shape
else:
output_shape = None
h = module(h, emb, context, transformer_options, output_shape)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)