From 805c36ac9c0611a483b1f494e2dfe6b67f09fe36 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 17 Feb 2024 00:42:22 -0500 Subject: [PATCH] Make Stable Cascade work on old pytorch 2.0 --- comfy/ldm/cascade/stage_a.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index 55fdbf17..260ccfc0 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -150,7 +150,11 @@ class ResBlock(nn.Module): mods = self.gammas x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] - x = x + self.depthwise(x_temp) * mods[2] + try: + x = x + self.depthwise(x_temp) * mods[2] + except: #operation not implemented for bf16 + x_temp = self.depthwise[0](x_temp.float()).to(x.dtype) + x = x + self.depthwise[1](x_temp) * mods[2] x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]