mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Make unet work with any input shape.
This commit is contained in:
parent
a307c3f12c
commit
66c8aa5c3e
@@ -76,12 +76,14 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
|||||||
support it as an extra input.
|
support it as an extra input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def forward(self, x, emb, context=None, transformer_options=None, output_shape=None):
    """Run ``x`` through every child layer, giving each the extra inputs it takes.

    Dispatch is by layer type:

    * ``TimestepBlock`` layers are called with ``(x, emb)``.
    * ``SpatialTransformer`` layers are called with
      ``(x, context, transformer_options)``.
    * ``Upsample`` layers are called with ``(x, output_shape=output_shape)``.
    * Any other layer is called with ``x`` only.

    Args:
        x: Input passed to the first layer; each layer's result feeds the next.
        emb: Embedding forwarded to ``TimestepBlock`` layers.
        context: Optional value forwarded to ``SpatialTransformer`` layers.
        transformer_options: Optional dict forwarded to ``SpatialTransformer``
            layers. ``None`` (the default) is replaced by a fresh empty dict.
        output_shape: Optional target shape forwarded to ``Upsample`` layers.

    Returns:
        The output of the last layer, or ``x`` unchanged if there are no layers.
    """
    # A literal ``{}`` default would be one dict shared by every call and
    # handed to child layers; build a new one per call instead.
    if transformer_options is None:
        transformer_options = {}

    for layer in self:
        if isinstance(layer, TimestepBlock):
            x = layer(x, emb)
        elif isinstance(layer, SpatialTransformer):
            x = layer(x, context, transformer_options)
        elif isinstance(layer, Upsample):
            x = layer(x, output_shape=output_shape)
        else:
            x = layer(x)
    return x
|
||||||
@@ -105,14 +107,21 @@ class Upsample(nn.Module):
|
|||||||
if use_conv:
|
if use_conv:
|
||||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
||||||
|
|
||||||
def forward(self, x, output_shape=None):
    """Upsample ``x`` with nearest-neighbour interpolation, then optionally convolve.

    Without ``output_shape`` the result is twice the input size along the
    upsampled axes. With ``output_shape`` the result takes its size along
    those axes from ``output_shape``, so it can match a tensor whose size is
    not an exact multiple of two.

    Args:
        x: Tensor whose axis 1 has ``self.channels`` entries and whose axes
            from index 2 onward are resized (presumably batch, channels,
            then spatial axes; verify against callers).
        output_shape: Optional full shape (same axis order as ``x``) of the
            tensor the result should match. When ``self.dims == 3`` only
            entries 3 and 4 are read; otherwise the entries from index 2 on.

    Returns:
        The resized tensor, passed through ``self.conv`` when
        ``self.use_conv`` is true.

    Raises:
        AssertionError: If ``x.shape[1]`` differs from ``self.channels``.
    """
    assert x.shape[1] == self.channels
    if self.dims == 3:
        # Axis 2 keeps its current size; only the last two axes are resized.
        shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
        if output_shape is not None:
            shape[1] = output_shape[3]
            shape[2] = output_shape[4]
    else:
        # Derive the target from however many trailing axes ``x`` has, so
        # 1-D inputs work as well as 2-D ones.
        shape = [size * 2 for size in x.shape[2:]]
        if output_shape is not None:
            shape = [output_shape[axis] for axis in range(2, 2 + len(shape))]

    x = F.interpolate(x, size=shape, mode="nearest")
    if self.use_conv:
        x = self.conv(x)
    return x
|
||||||
@@ -813,9 +822,14 @@ class UNetModel(nn.Module):
|
|||||||
ctrl = control['output'].pop()
|
ctrl = control['output'].pop()
|
||||||
if ctrl is not None:
|
if ctrl is not None:
|
||||||
hsp += ctrl
|
hsp += ctrl
|
||||||
|
|
||||||
h = th.cat([h, hsp], dim=1)
|
h = th.cat([h, hsp], dim=1)
|
||||||
del hsp
|
del hsp
|
||||||
h = module(h, emb, context, transformer_options)
|
if len(hs) > 0:
|
||||||
|
output_shape = hs[-1].shape
|
||||||
|
else:
|
||||||
|
output_shape = None
|
||||||
|
h = module(h, emb, context, transformer_options, output_shape)
|
||||||
h = h.type(x.dtype)
|
h = h.type(x.dtype)
|
||||||
if self.predict_codebook_ids:
|
if self.predict_codebook_ids:
|
||||||
return self.id_predictor(h)
|
return self.id_predictor(h)
|
||||||
|
Loading…
Reference in New Issue
Block a user