Use in-place operations to reduce OOM issues.

This commit is contained in:
comfyanonymous 2023-02-08 20:52:02 -05:00
parent 3fd87cbd21
commit 5b4e312749

View File

@ -96,6 +96,7 @@ class ResnetBlock(nn.Module):
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
@ -106,7 +107,7 @@ class ResnetBlock(nn.Module):
self.temb_proj = torch.nn.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
@ -129,14 +130,14 @@ class ResnetBlock(nn.Module):
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.swish(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.swish(h)
h = self.dropout(h)
h = self.conv2(h)