diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py index b0b86f23..fa573299 100644 --- a/comfy_extras/nodes_fresca.py +++ b/comfy_extras/nodes_fresca.py @@ -26,19 +26,17 @@ def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): x_freq = fft.fftn(x, dim=(-2, -1)) x_freq = fft.fftshift(x_freq, dim=(-2, -1)) - # 2) Create a mask to scale frequencies differently - B, C, H, W = x_freq.shape - crow, ccol = H // 2, W // 2 - # Initialize mask with high-frequency scaling factor - mask = torch.ones((B, C, H, W), device=device) * scale_high + mask = torch.ones(x_freq.shape, device=device) * scale_high + m = mask + for d in range(len(x_freq.shape) - 2): + dim = d + 2 + cc = x_freq.shape[dim] // 2 + f_c = min(freq_cutoff, cc) + m = m.narrow(dim, cc - f_c, f_c * 2) # Apply low-frequency scaling factor to center region - mask[ - ..., - crow - freq_cutoff : crow + freq_cutoff, - ccol - freq_cutoff : ccol + freq_cutoff, - ] = scale_low + m[:] = scale_low # 3) Apply frequency-specific scaling x_freq = x_freq * mask