From 9c1ed58ef218c8a20bd862de329d7d350d69a34d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 10 Nov 2024 00:10:45 -0500 Subject: [PATCH] proper fix for sag. --- comfy_extras/nodes_sag.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 3f03533e..1bd8d736 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -58,16 +58,23 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): # Global Average Pool mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold - f = float(lh) / float(lw) - fh = f ** 0.5 - fw = (1/f) ** 0.5 - S = mask.size(1) ** 0.5 - w = int(0.5 + S * fw) - h = int(0.5 + S * fh) + total = mask.shape[-1] + x = round(math.sqrt((lh / lw) * total)) + xx = None + for i in range(0, math.floor(math.sqrt(total) / 2)): + for j in [(x + i), max(1, x - i)]: + if total % j == 0: + xx = j + break + if xx is not None: + break + + x = xx + y = total // x # Reshape mask = ( - mask.reshape(b, h, w) + mask.reshape(b, x, y) .unsqueeze(1) .type(attn.dtype) )