Make --bf16-vae work on torch 2.0

This commit is contained in:
comfyanonymous 2023-08-27 21:33:53 -04:00
parent 412596d325
commit d935ba50c4

View File

@ -56,7 +56,18 @@ class Upsample(nn.Module):
padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
try:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
except: #operation not implemented for bf16
b, c, h, w = x.shape
out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
del x
x = out
if self.with_conv:
x = self.conv(x)
return x