Make fresca work on multi dim.

This commit is contained in:
comfyanonymous 2025-04-17 15:46:41 -04:00
parent 19373aee75
commit 3dc240d089

View File

@ -26,19 +26,17 @@ def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
x_freq = fft.fftn(x, dim=(-2, -1)) x_freq = fft.fftn(x, dim=(-2, -1))
x_freq = fft.fftshift(x_freq, dim=(-2, -1)) x_freq = fft.fftshift(x_freq, dim=(-2, -1))
# 2) Create a mask to scale frequencies differently
B, C, H, W = x_freq.shape
crow, ccol = H // 2, W // 2
# Initialize mask with high-frequency scaling factor # Initialize mask with high-frequency scaling factor
mask = torch.ones((B, C, H, W), device=device) * scale_high mask = torch.ones(x_freq.shape, device=device) * scale_high
m = mask
for d in range(len(x_freq.shape) - 2):
dim = d + 2
cc = x_freq.shape[dim] // 2
f_c = min(freq_cutoff, cc)
m = m.narrow(dim, cc - f_c, f_c * 2)
# Apply low-frequency scaling factor to center region # Apply low-frequency scaling factor to center region
mask[ m[:] = scale_low
...,
crow - freq_cutoff : crow + freq_cutoff,
ccol - freq_cutoff : ccol + freq_cutoff,
] = scale_low
# 3) Apply frequency-specific scaling # 3) Apply frequency-specific scaling
x_freq = x_freq * mask x_freq = x_freq * mask