This commit is contained in:
comfyanonymous 2023-04-06 23:45:29 -04:00
commit 28a7205739
2 changed files with 38 additions and 13 deletions

View File

@ -18,10 +18,17 @@ total_vram = 0
total_vram_available_mb = -1 total_vram_available_mb = -1
accelerate_enabled = False accelerate_enabled = False
xpu_available = False
try: try:
import torch import torch
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
except:
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024)
if not args.normalvram and not args.cpu: if not args.normalvram and not args.cpu:
if total_vram <= 4096: if total_vram <= 4096:
@ -122,6 +129,7 @@ def load_model_gpu(model):
global current_loaded_model global current_loaded_model
global vram_state global vram_state
global model_accelerated global model_accelerated
global xpu_available
if model is current_loaded_model: if model is current_loaded_model:
return return
@ -140,14 +148,17 @@ def load_model_gpu(model):
pass pass
elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM: elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
model_accelerated = False model_accelerated = False
real_model.cuda() if xpu_available:
real_model.to("xpu")
else:
real_model.cuda()
else: else:
if vram_state == VRAMState.NO_VRAM: if vram_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
elif vram_state == VRAMState.LOW_VRAM: elif vram_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda") accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
model_accelerated = True model_accelerated = True
return current_loaded_model return current_loaded_model
@ -173,8 +184,12 @@ def load_controlnet_gpu(models):
def load_if_low_vram(model): def load_if_low_vram(model):
global vram_state global vram_state
global xpu_available
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
return model.cuda() if xpu_available:
return model.to("xpu")
else:
return model.cuda()
return model return model
def unload_if_low_vram(model): def unload_if_low_vram(model):
@ -184,12 +199,16 @@ def unload_if_low_vram(model):
return model return model
def get_torch_device(): def get_torch_device():
global xpu_available
if vram_state == VRAMState.MPS: if vram_state == VRAMState.MPS:
return torch.device("mps") return torch.device("mps")
if vram_state == VRAMState.CPU: if vram_state == VRAMState.CPU:
return torch.device("cpu") return torch.device("cpu")
else: else:
return torch.cuda.current_device() if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_autocast_device(dev): def get_autocast_device(dev):
if hasattr(dev, 'type'): if hasattr(dev, 'type'):
@ -219,6 +238,7 @@ def pytorch_attention_enabled():
return ENABLE_PYTORCH_ATTENTION return ENABLE_PYTORCH_ATTENTION
def get_free_memory(dev=None, torch_free_too=False): def get_free_memory(dev=None, torch_free_too=False):
global xpu_available
if dev is None: if dev is None:
dev = get_torch_device() dev = get_torch_device()
@ -226,12 +246,16 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_total = psutil.virtual_memory().available mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total mem_free_torch = mem_free_total
else: else:
stats = torch.cuda.memory_stats(dev) if xpu_available:
mem_active = stats['active_bytes.all.current'] mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
mem_reserved = stats['reserved_bytes.all.current'] mem_free_torch = mem_free_total
mem_free_cuda, _ = torch.cuda.mem_get_info(dev) else:
mem_free_torch = mem_reserved - mem_active stats = torch.cuda.memory_stats(dev)
mem_free_total = mem_free_cuda + mem_free_torch mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
if torch_free_too: if torch_free_too:
return (mem_free_total, mem_free_torch) return (mem_free_total, mem_free_torch)
@ -256,7 +280,8 @@ def mps_mode():
return vram_state == VRAMState.MPS return vram_state == VRAMState.MPS
def should_use_fp16(): def should_use_fp16():
if cpu_mode() or mps_mode(): global xpu_available
if cpu_mode() or mps_mode() or xpu_available:
return False #TODO ? return False #TODO ?
if torch.cuda.is_bf16_supported(): if torch.cuda.is_bf16_supported():

View File

@ -4,7 +4,7 @@ torchsde
einops einops
open-clip-torch open-clip-torch
transformers>=4.25.1 transformers>=4.25.1
safetensors safetensors>=0.3.0
pytorch_lightning pytorch_lightning
aiohttp aiohttp
accelerate accelerate