mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Do FreeU fft on CPU if the device doesn't support fft functions.
This commit is contained in:
parent
77c124c5a1
commit
f00471cdc8
@ -39,11 +39,22 @@ class FreeU:
|
|||||||
def patch(self, model, b1, b2, s1, s2):
|
def patch(self, model, b1, b2, s1, s2):
|
||||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||||
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
||||||
|
on_cpu_devices = {}
|
||||||
|
|
||||||
def output_block_patch(h, hsp, transformer_options):
|
def output_block_patch(h, hsp, transformer_options):
|
||||||
scale = scale_dict.get(h.shape[1], None)
|
scale = scale_dict.get(h.shape[1], None)
|
||||||
if scale is not None:
|
if scale is not None:
|
||||||
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
|
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
|
||||||
|
if hsp.device not in on_cpu_devices:
|
||||||
|
try:
|
||||||
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
|
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
|
||||||
|
except:
|
||||||
|
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
|
||||||
|
on_cpu_devices[hsp.device] = True
|
||||||
|
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
||||||
|
else:
|
||||||
|
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
||||||
|
|
||||||
return h, hsp
|
return h, hsp
|
||||||
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
Loading…
Reference in New Issue
Block a user