mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Use inplace operations for less OOM issues.
This commit is contained in:
parent
3fd87cbd21
commit
5b4e312749
@ -96,6 +96,7 @@ class ResnetBlock(nn.Module):
|
|||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
self.use_conv_shortcut = conv_shortcut
|
self.use_conv_shortcut = conv_shortcut
|
||||||
|
|
||||||
|
self.swish = torch.nn.SiLU(inplace=True)
|
||||||
self.norm1 = Normalize(in_channels)
|
self.norm1 = Normalize(in_channels)
|
||||||
self.conv1 = torch.nn.Conv2d(in_channels,
|
self.conv1 = torch.nn.Conv2d(in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
@ -106,7 +107,7 @@ class ResnetBlock(nn.Module):
|
|||||||
self.temb_proj = torch.nn.Linear(temb_channels,
|
self.temb_proj = torch.nn.Linear(temb_channels,
|
||||||
out_channels)
|
out_channels)
|
||||||
self.norm2 = Normalize(out_channels)
|
self.norm2 = Normalize(out_channels)
|
||||||
self.dropout = torch.nn.Dropout(dropout)
|
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
||||||
self.conv2 = torch.nn.Conv2d(out_channels,
|
self.conv2 = torch.nn.Conv2d(out_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
@ -129,14 +130,14 @@ class ResnetBlock(nn.Module):
|
|||||||
def forward(self, x, temb):
|
def forward(self, x, temb):
|
||||||
h = x
|
h = x
|
||||||
h = self.norm1(h)
|
h = self.norm1(h)
|
||||||
h = nonlinearity(h)
|
h = self.swish(h)
|
||||||
h = self.conv1(h)
|
h = self.conv1(h)
|
||||||
|
|
||||||
if temb is not None:
|
if temb is not None:
|
||||||
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
|
||||||
|
|
||||||
h = self.norm2(h)
|
h = self.norm2(h)
|
||||||
h = nonlinearity(h)
|
h = self.swish(h)
|
||||||
h = self.dropout(h)
|
h = self.dropout(h)
|
||||||
h = self.conv2(h)
|
h = self.conv2(h)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user