Make unet work with any input shape.

This commit is contained in:
comfyanonymous 2023-05-02 13:31:43 -04:00
parent a307c3f12c
commit 66c8aa5c3e

View File

@ -76,12 +76,14 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
support it as an extra input. support it as an extra input.
""" """
def forward(self, x, emb, context=None, transformer_options=None, output_shape=None):
    """Run ``x`` through every child layer, giving each the extra inputs it accepts.

    :param x: the input passed to the first layer; each layer's result is
        fed to the next one.
    :param emb: forwarded to layers that are ``TimestepBlock`` instances.
    :param context: forwarded to ``SpatialTransformer`` layers.
    :param transformer_options: forwarded to ``SpatialTransformer`` layers.
        Defaults to a fresh empty dict per call.
    :param output_shape: forwarded to ``Upsample`` layers as the
        ``output_shape`` keyword; ``None`` lets the layer pick its own size.
    :return: the output of the last layer.
    """
    # A literal `{}` default would be one dict shared by every call; build a
    # new one per call so nothing can leak between invocations.
    if transformer_options is None:
        transformer_options = {}
    for layer in self:
        if isinstance(layer, TimestepBlock):
            x = layer(x, emb)
        elif isinstance(layer, SpatialTransformer):
            x = layer(x, context, transformer_options)
        elif isinstance(layer, Upsample):
            x = layer(x, output_shape=output_shape)
        else:
            x = layer(x)
    return x
@ -105,14 +107,21 @@ class Upsample(nn.Module):
if use_conv: if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x, output_shape=None):
    """Upsample ``x`` with nearest-neighbour interpolation, then optionally convolve.

    :param x: tensor whose dimension 1 must equal ``self.channels``.
    :param output_shape: optional full shape (indexed like ``x.shape``) of
        the tensor the result should match spatially. When given, its
        trailing spatial sizes replace the default "double each axis"
        target. For ``dims == 3`` only the last two axes are taken from it;
        axis 2 always keeps its input size.
    :return: the resized tensor, passed through ``self.conv`` when
        ``self.use_conv`` is set.
    :raises AssertionError: if ``x.shape[1] != self.channels``.
    """
    assert x.shape[1] == self.channels
    if self.dims == 3:
        # Axis 2 is left untouched; only the two inner axes are resized.
        shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
        if output_shape is not None:
            shape[1] = output_shape[3]
            shape[2] = output_shape[4]
    else:
        # Double every spatial axis. Deriving the list from x.shape[2:]
        # (instead of indexing axes 2 and 3 explicitly) keeps 1D inputs
        # working, as they did with the former scale_factor=2 call.
        shape = [size * 2 for size in x.shape[2:]]
        if output_shape is not None:
            shape = list(output_shape[2:2 + len(shape)])
    x = F.interpolate(x, size=shape, mode="nearest")
    if self.use_conv:
        x = self.conv(x)
    return x
@ -813,9 +822,14 @@ class UNetModel(nn.Module):
ctrl = control['output'].pop() ctrl = control['output'].pop()
if ctrl is not None: if ctrl is not None:
hsp += ctrl hsp += ctrl
h = th.cat([h, hsp], dim=1) h = th.cat([h, hsp], dim=1)
del hsp del hsp
h = module(h, emb, context, transformer_options) if len(hs) > 0:
output_shape = hs[-1].shape
else:
output_shape = None
h = module(h, emb, context, transformer_options, output_shape)
h = h.type(x.dtype) h = h.type(x.dtype)
if self.predict_codebook_ids: if self.predict_codebook_ids:
return self.id_predictor(h) return self.id_predictor(h)