Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-20 03:13:30 +00:00)
Add ipex optimize and other enhancements for Intel GPUs based on recent memory changes.
Commit 2c096e4260 (parent 8ee0473687)
@@ -58,6 +58,8 @@ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in
 
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
 
+parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
+
 class LatentPreviewMethod(enum.Enum):
     NoPreviews = "none"
     Auto = "auto"
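
Note: a minimal, standalone sketch of how a store_true flag such as the new --disable-ipex-optimize behaves under argparse; the demo_parser below is a hypothetical stand-in, not the parser from this repository.

    import argparse

    demo_parser = argparse.ArgumentParser()
    demo_parser.add_argument("--disable-ipex-optimize", action="store_true",
                             help="Disables ipex.optimize when loading models with Intel GPUs.")

    # Flag omitted: defaults to False, so ipex.optimize stays enabled.
    assert demo_parser.parse_args([]).disable_ipex_optimize is False
    # Flag passed: parses to True, which later skips the torch.xpu.optimize() call.
    assert demo_parser.parse_args(["--disable-ipex-optimize"]).disable_ipex_optimize is True
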
@@ -88,8 +88,10 @@ def get_total_memory(dev=None, torch_total_too=False):
         mem_total = 1024 * 1024 * 1024 #TODO
         mem_total_torch = mem_total
     elif xpu_available:
+        stats = torch.xpu.memory_stats(dev)
+        mem_reserved = stats['reserved_bytes.all.current']
         mem_total = torch.xpu.get_device_properties(dev).total_memory
-        mem_total_torch = mem_total
+        mem_total_torch = mem_reserved
     else:
         stats = torch.cuda.memory_stats(dev)
         mem_reserved = stats['reserved_bytes.all.current']
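
Note: a hedged, standalone sketch of the new XPU branch above; it assumes intel_extension_for_pytorch is installed so that the torch.xpu calls used in the hunk are available.

    import torch
    import intel_extension_for_pytorch as ipex  # noqa: F401 -- assumed to register torch.xpu

    def xpu_total_memory_sketch(dev):
        # Mirror of the hunk: total device memory plus what the caching allocator has reserved.
        stats = torch.xpu.memory_stats(dev)
        mem_reserved = stats['reserved_bytes.all.current']
        mem_total = torch.xpu.get_device_properties(dev).total_memory
        return mem_total, mem_reserved
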
@@ -208,6 +210,7 @@ if DISABLE_SMART_MEMORY:
     print("Disabling smart memory management")
 
 def get_torch_device_name(device):
+    global xpu_available
     if hasattr(device, 'type'):
         if device.type == "cuda":
             try:
@@ -217,6 +220,8 @@ def get_torch_device_name(device):
             return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
         else:
             return "{}".format(device.type)
+    elif xpu_available:
+        return "{} {}".format(device, torch.xpu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 
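
Note: a small sketch of the device-name branch added above, with xpu_available passed in as a plain boolean rather than the module-level global.

    import torch

    def describe_device_sketch(device, xpu_available):
        # New elif branch: label XPU devices with their device name instead of assuming CUDA.
        if xpu_available:
            return "{} {}".format(device, torch.xpu.get_device_name(device))
        return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
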
@@ -244,6 +249,7 @@ class LoadedModel:
         return self.model_memory()
 
     def model_load(self, lowvram_model_memory=0):
+        global xpu_available
         patch_model_to = None
         if lowvram_model_memory == 0:
             patch_model_to = self.device
@@ -264,6 +270,10 @@ class LoadedModel:
             accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
             self.model_accelerated = True
 
+        if xpu_available and not args.disable_ipex_optimize:
+            self.real_model.training = False
+            self.real_model = torch.xpu.optimize(self.real_model, inplace=True)
+
         return self.real_model
 
     def model_unload(self):
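
Note: a hedged sketch of what the new block does at load time: flip the model out of training mode and hand it to IPEX's optimizer, torch.xpu.optimize, which is the call used in the hunk. The tiny model below is a stand-in, not ComfyUI's real_model, and the snippet assumes an XPU device plus intel_extension_for_pytorch are present.

    import torch
    import intel_extension_for_pytorch as ipex  # noqa: F401 -- assumed to provide torch.xpu.optimize

    model = torch.nn.Linear(16, 16).to("xpu")
    model.training = False                           # same flag flip as in the hunk
    model = torch.xpu.optimize(model, inplace=True)  # fuse/optimize for inference on XPU

    with torch.no_grad():
        out = model(torch.randn(1, 16, device="xpu"))
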
@@ -500,8 +510,12 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_total = 1024 * 1024 * 1024 #TODO
         mem_free_torch = mem_free_total
     elif xpu_available:
-        mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
-        mem_free_torch = mem_free_total
+        stats = torch.xpu.memory_stats(dev)
+        mem_active = stats['active_bytes.all.current']
+        mem_allocated = stats['allocated_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated + mem_free_torch
     else:
         stats = torch.cuda.memory_stats(dev)
         mem_active = stats['active_bytes.all.current']
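
Note: a hedged sketch of the reworked free-memory estimate for XPU. Free memory inside the torch allocator is reserved minus active bytes; total free memory is the device total minus what torch has allocated, plus that reusable cached space. Same torch.xpu APIs as the hunk, assumed to come from intel_extension_for_pytorch.

    import torch

    def xpu_free_memory_sketch(dev):
        stats = torch.xpu.memory_stats(dev)
        mem_active = stats['active_bytes.all.current']        # bytes in live allocations
        mem_allocated = stats['allocated_bytes.all.current']  # bytes currently allocated by torch
        mem_reserved = stats['reserved_bytes.all.current']    # bytes the allocator holds from the device
        mem_free_torch = mem_reserved - mem_active            # cached space torch can reuse
        mem_free_total = (torch.xpu.get_device_properties(dev).total_memory
                          - mem_allocated + mem_free_torch)
        return mem_free_total, mem_free_torch
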
@@ -573,10 +587,10 @@ def should_use_fp16(device=None, model_params=0):
     if directml_enabled:
         return False
 
-    if cpu_mode() or mps_mode() or xpu_available:
+    if cpu_mode() or mps_mode():
         return False #TODO ?
 
-    if torch.cuda.is_bf16_supported():
+    if torch.cuda.is_bf16_supported() or xpu_available:
         return True
 
     props = torch.cuda.get_device_properties("cuda")
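
Note: this hunk changes the fp16 policy for Intel GPUs: xpu_available no longer forces fp16 off and instead counts as fp16-capable alongside CUDA devices that support bf16. A condensed, hedged sketch of the resulting decision order, with plain booleans standing in for the runtime checks:

    def should_use_fp16_sketch(directml_enabled, cpu_or_mps, xpu_available, cuda_bf16_supported):
        if directml_enabled:
            return False
        if cpu_or_mps:
            return False  # TODO ? (kept from the original)
        if cuda_bf16_supported or xpu_available:
            return True
        # ...the original continues by inspecting torch.cuda.get_device_properties("cuda")
        return False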