Try to fix memory issue.

comfyanonymous 2023-03-11 15:15:13 -05:00
parent 2b1fce2943
commit 1de86851b1


@@ -82,8 +82,9 @@ class Downsample(nn.Module):
                                         stride=2,
                                         padding=0)
 
-    def forward(self, x):
+    def forward(self, x, already_padded=False):
         if self.with_conv:
-            pad = (0,1,0,1)
-            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            if not already_padded:
+                pad = (0,1,0,1)
+                x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
             x = self.conv(x)
@@ -553,7 +554,9 @@ class Encoder(nn.Module):
     def forward(self, x):
         # timestep embedding
         temb = None
-
+        pad = (0,1,0,1)
+        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+        already_padded = True
         # downsampling
         hs = [self.conv_in(x)]
         for i_level in range(self.num_resolutions):
@@ -563,7 +566,8 @@ class Encoder(nn.Module):
                 h = self.down[i_level].attn[i_block](h)
                 hs.append(h)
             if i_level != self.num_resolutions-1:
-                hs.append(self.down[i_level].downsample(hs[-1]))
+                hs.append(self.down[i_level].downsample(hs[-1], already_padded))
+                already_padded = False
 
         # middle
         h = hs[-1]
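
Note: below is a minimal, self-contained sketch of the pattern this diff introduces, using a simplified Downsample (single channel argument, illustrative 64x64 input; the real classes carry more configuration). The idea is that the caller may apply the asymmetric (0,1,0,1) pad itself and pass already_padded=True so the module skips its own pad, while the default already_padded=False keeps the previous behavior.

import torch
import torch.nn as nn

class Downsample(nn.Module):
    # Simplified stand-in for the Downsample class touched by this commit.
    def __init__(self, channels):
        super().__init__()
        self.with_conv = True
        # torch convs have no asymmetric padding, so it is done manually below
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x, already_padded=False):
        if self.with_conv:
            if not already_padded:
                # pad right/bottom by one so the 3x3 stride-2 conv halves H and W
                x = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x

x = torch.randn(1, 4, 64, 64)
down = Downsample(4)
# The caller pads once up front (as Encoder.forward now does) and says so:
x = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
out = down(x, already_padded=True)   # no second pad inside forward
print(out.shape)                     # torch.Size([1, 4, 32, 32])

In the Encoder hunks above, the input is padded once before conv_in and the flag is passed only to the first downsample level (it is reset to False right after), so the first Downsample no longer allocates its own padded copy of the full-resolution feature map; presumably that allocation is the memory issue the commit message refers to.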