diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index c3542a7a..07a88bd9 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -39,11 +39,22 @@ class FreeU: def patch(self, model, b1, b2, s1, s2): model_channels = model.model.model_config.unet_config["model_channels"] scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} + on_cpu_devices = {} + def output_block_patch(h, hsp, transformer_options): scale = scale_dict.get(h.shape[1], None) if scale is not None: h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0] - hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) + if hsp.device not in on_cpu_devices: + try: + hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) + except: + print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") + on_cpu_devices[hsp.device] = True + hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) + else: + hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) + return h, hsp m = model.clone()