diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 374dd2f7..b4f22f31 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -58,6 +58,8 @@ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in
 
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
 
+parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
+
 class LatentPreviewMethod(enum.Enum):
     NoPreviews = "none"
     Auto = "auto"
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 80f6620a..83e2f6df 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -88,8 +88,10 @@ def get_total_memory(dev=None, torch_total_too=False):
             mem_total = 1024 * 1024 * 1024 #TODO
             mem_total_torch = mem_total
         elif xpu_available:
+            stats = torch.xpu.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
             mem_total = torch.xpu.get_device_properties(dev).total_memory
-            mem_total_torch = mem_total
+            mem_total_torch = mem_reserved
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
@@ -208,6 +210,7 @@ if DISABLE_SMART_MEMORY:
     print("Disabling smart memory management")
 
 def get_torch_device_name(device):
+    global xpu_available
     if hasattr(device, 'type'):
         if device.type == "cuda":
             try:
@@ -217,6 +220,8 @@ def get_torch_device_name(device):
             return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
         else:
             return "{}".format(device.type)
+    elif xpu_available:
+        return "{} {}".format(device, torch.xpu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 
@@ -244,6 +249,7 @@ class LoadedModel:
         return self.model_memory()
 
     def model_load(self, lowvram_model_memory=0):
+        global xpu_available
         patch_model_to = None
         if lowvram_model_memory == 0:
             patch_model_to = self.device
@@ -264,6 +270,10 @@ class LoadedModel:
             accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
             self.model_accelerated = True
 
+        if xpu_available and not args.disable_ipex_optimize:
+            self.real_model.training = False
+            self.real_model = torch.xpu.optimize(self.real_model, inplace=True)
+
         return self.real_model
 
     def model_unload(self):
@@ -500,8 +510,12 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_total = 1024 * 1024 * 1024 #TODO
             mem_free_torch = mem_free_total
         elif xpu_available:
-            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
-            mem_free_torch = mem_free_total
+            stats = torch.xpu.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_allocated = stats['allocated_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated + mem_free_torch
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
@@ -573,10 +587,10 @@ def should_use_fp16(device=None, model_params=0):
     if directml_enabled:
         return False
 
-    if cpu_mode() or mps_mode() or xpu_available:
+    if cpu_mode() or mps_mode():
         return False #TODO ?
 
-    if torch.cuda.is_bf16_supported():
+    if torch.cuda.is_bf16_supported() or xpu_available:
         return True
 
     props = torch.cuda.get_device_properties("cuda")