Remove the optimization that caused a border artifact.

This commit is contained in:
comfyanonymous 2023-08-29 11:20:17 -04:00
parent 8ddd081b09
commit bed116a1f9

View File

@@ -85,11 +85,10 @@ class Downsample(nn.Module):
                                     stride=2,
                                     padding=0)

-    def forward(self, x, already_padded=False):
+    def forward(self, x):
         if self.with_conv:
-            if not already_padded:
-                pad = (0,1,0,1)
-                x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            pad = (0,1,0,1)
+            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
             x = self.conv(x)
         else:
             x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
@@ -606,9 +605,6 @@ class Encoder(nn.Module):
     def forward(self, x):
         # timestep embedding
         temb = None
-        pad = (0,1,0,1)
-        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
-        already_padded = True

         # downsampling
         h = self.conv_in(x)
         for i_level in range(self.num_resolutions):
@@ -617,8 +613,7 @@ class Encoder(nn.Module):
                 if len(self.down[i_level].attn) > 0:
                     h = self.down[i_level].attn[i_block](h)
                 if i_level != self.num_resolutions-1:
-                    h = self.down[i_level].downsample(h, already_padded)
-                    already_padded = False
+                    h = self.down[i_level].downsample(h)

         # middle
         h = self.mid.block_1(h, temb)