diff --git a/comfy/model_management.py b/comfy/model_management.py
index 7dda073d..92c59efe 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -18,10 +18,17 @@
 total_vram = 0
 total_vram_available_mb = -1
 accelerate_enabled = False
+xpu_available = False
 
 try:
     import torch
-    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
+    try:
+        import intel_extension_for_pytorch as ipex
+        if torch.xpu.is_available():
+            xpu_available = True
+            total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
+    except:
+        total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
     total_ram = psutil.virtual_memory().total / (1024 * 1024)
     if not args.normalvram and not args.cpu:
         if total_vram <= 4096:
@@ -122,6 +129,7 @@ def load_model_gpu(model):
     global current_loaded_model
     global vram_state
     global model_accelerated
+    global xpu_available
 
     if model is current_loaded_model:
         return
@@ -140,14 +148,17 @@ def load_model_gpu(model):
         pass
     elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
         model_accelerated = False
-        real_model.cuda()
+        if xpu_available:
+            real_model.to("xpu")
+        else:
+            real_model.cuda()
     else:
         if vram_state == VRAMState.NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         elif vram_state == VRAMState.LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
 
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
         model_accelerated = True
 
     return current_loaded_model
@@ -173,8 +184,12 @@ def load_controlnet_gpu(models):
 
 def load_if_low_vram(model):
     global vram_state
+    global xpu_available
     if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
-        return model.cuda()
+        if xpu_available:
+            return model.to("xpu")
+        else:
+            return model.cuda()
     return model
 
 def unload_if_low_vram(model):
@@ -184,12 +199,16 @@ def unload_if_low_vram(model):
     return model
 
 def get_torch_device():
+    global xpu_available
     if vram_state == VRAMState.MPS:
         return torch.device("mps")
     if vram_state == VRAMState.CPU:
         return torch.device("cpu")
     else:
-        return torch.cuda.current_device()
+        if xpu_available:
+            return torch.device("xpu")
+        else:
+            return torch.cuda.current_device()
 
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
@@ -219,6 +238,7 @@ def pytorch_attention_enabled():
     return ENABLE_PYTORCH_ATTENTION
 
 def get_free_memory(dev=None, torch_free_too=False):
+    global xpu_available
     if dev is None:
         dev = get_torch_device()
 
@@ -226,12 +246,16 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
-        stats = torch.cuda.memory_stats(dev)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        if xpu_available:
+            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
+            mem_free_torch = mem_free_total
+        else:
+            stats = torch.cuda.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_cuda + mem_free_torch
 
     if torch_free_too:
         return (mem_free_total, mem_free_torch)
@@ -256,7 +280,8 @@ def mps_mode():
     return vram_state == VRAMState.MPS
 
 def should_use_fp16():
-    if cpu_mode() or mps_mode():
+    global xpu_available
+    if cpu_mode() or mps_mode() or xpu_available:
        return False #TODO ?
 
     if torch.cuda.is_bf16_supported():
diff --git a/requirements.txt b/requirements.txt
index 3b4040a2..0527b31d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ torchsde
 einops
 open-clip-torch
 transformers>=4.25.1
-safetensors
+safetensors>=0.3.0
 pytorch_lightning
 aiohttp
 accelerate
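
Not part of the patch: a minimal standalone sketch of the detection-and-device-selection pattern the diff applies throughout model_management.py. It assumes only that PyTorch is installed and that intel_extension_for_pytorch may be absent; the helper name pick_device is illustrative, not from the patch.

    import torch

    # Mirror the patch's module-level probe: importing the Intel extension
    # registers torch.xpu; if the import or the probe fails, fall back to CUDA/CPU.
    xpu_available = False
    try:
        import intel_extension_for_pytorch as ipex  # noqa: F401
        xpu_available = torch.xpu.is_available()
    except Exception:
        pass

    def pick_device():
        # Same priority order as the patched get_torch_device():
        # XPU when the Intel extension reports a usable device, then CUDA, then CPU.
        if xpu_available:
            return torch.device("xpu")
        if torch.cuda.is_available():
            return torch.device("cuda", torch.cuda.current_device())
        return torch.device("cpu")

    print(pick_device())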