mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Use separate variables instead of vram_state
This commit is contained in:
parent
84b9c0ac2f
commit
7cb924f684
@ -5,9 +5,9 @@ LOW_VRAM = 2
|
|||||||
NORMAL_VRAM = 3
|
NORMAL_VRAM = 3
|
||||||
HIGH_VRAM = 4
|
HIGH_VRAM = 4
|
||||||
MPS = 5
|
MPS = 5
|
||||||
XPU = 6
|
|
||||||
|
|
||||||
accelerate_enabled = False
|
accelerate_enabled = False
|
||||||
|
xpu_available = False
|
||||||
vram_state = NORMAL_VRAM
|
vram_state = NORMAL_VRAM
|
||||||
|
|
||||||
total_vram = 0
|
total_vram = 0
|
||||||
@ -22,6 +22,11 @@ set_vram_to = NORMAL_VRAM
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
if torch.xpu.is_available():
|
||||||
|
xpu_available = True
|
||||||
|
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
|
||||||
|
else:
|
||||||
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
|
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
|
||||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||||
forced_normal_vram = "--normalvram" in sys.argv
|
forced_normal_vram = "--normalvram" in sys.argv
|
||||||
@ -86,17 +91,10 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
if torch.xpu.is_available():
|
|
||||||
vram_state = XPU
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if forced_cpu:
|
if forced_cpu:
|
||||||
vram_state = CPU
|
vram_state = CPU
|
||||||
|
|
||||||
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS", "XPU"][vram_state])
|
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
|
||||||
|
|
||||||
|
|
||||||
current_loaded_model = None
|
current_loaded_model = None
|
||||||
@ -133,6 +131,7 @@ def load_model_gpu(model):
|
|||||||
global current_loaded_model
|
global current_loaded_model
|
||||||
global vram_state
|
global vram_state
|
||||||
global model_accelerated
|
global model_accelerated
|
||||||
|
global xpu_available
|
||||||
|
|
||||||
if model is current_loaded_model:
|
if model is current_loaded_model:
|
||||||
return
|
return
|
||||||
@ -149,11 +148,11 @@ def load_model_gpu(model):
|
|||||||
mps_device = torch.device("mps")
|
mps_device = torch.device("mps")
|
||||||
real_model.to(mps_device)
|
real_model.to(mps_device)
|
||||||
pass
|
pass
|
||||||
elif vram_state == XPU:
|
|
||||||
real_model.to("xpu")
|
|
||||||
pass
|
|
||||||
elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
|
elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
|
||||||
model_accelerated = False
|
model_accelerated = False
|
||||||
|
if xpu_available:
|
||||||
|
real_model.to("xpu")
|
||||||
|
else:
|
||||||
real_model.cuda()
|
real_model.cuda()
|
||||||
else:
|
else:
|
||||||
if vram_state == NO_VRAM:
|
if vram_state == NO_VRAM:
|
||||||
@ -161,7 +160,7 @@ def load_model_gpu(model):
|
|||||||
elif vram_state == LOW_VRAM:
|
elif vram_state == LOW_VRAM:
|
||||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
|
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
|
||||||
|
|
||||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
|
accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
|
||||||
model_accelerated = True
|
model_accelerated = True
|
||||||
return current_loaded_model
|
return current_loaded_model
|
||||||
|
|
||||||
@ -187,7 +186,11 @@ def load_controlnet_gpu(models):
|
|||||||
|
|
||||||
def load_if_low_vram(model):
|
def load_if_low_vram(model):
|
||||||
global vram_state
|
global vram_state
|
||||||
|
global xpu_available
|
||||||
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
|
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
|
||||||
|
if xpu_available:
|
||||||
|
return model.to("xpu")
|
||||||
|
else:
|
||||||
return model.cuda()
|
return model.cuda()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -198,12 +201,14 @@ def unload_if_low_vram(model):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def get_torch_device():
|
def get_torch_device():
|
||||||
|
global xpu_available
|
||||||
if vram_state == MPS:
|
if vram_state == MPS:
|
||||||
return torch.device("mps")
|
return torch.device("mps")
|
||||||
if vram_state == XPU:
|
|
||||||
return torch.device("xpu")
|
|
||||||
if vram_state == CPU:
|
if vram_state == CPU:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
|
else:
|
||||||
|
if xpu_available:
|
||||||
|
return torch.device("xpu")
|
||||||
else:
|
else:
|
||||||
return torch.cuda.current_device()
|
return torch.cuda.current_device()
|
||||||
|
|
||||||
@ -235,13 +240,15 @@ def pytorch_attention_enabled():
|
|||||||
return ENABLE_PYTORCH_ATTENTION
|
return ENABLE_PYTORCH_ATTENTION
|
||||||
|
|
||||||
def get_free_memory(dev=None, torch_free_too=False):
|
def get_free_memory(dev=None, torch_free_too=False):
|
||||||
|
global xpu_available
|
||||||
if dev is None:
|
if dev is None:
|
||||||
dev = get_torch_device()
|
dev = get_torch_device()
|
||||||
|
|
||||||
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
||||||
mem_free_total = psutil.virtual_memory().available
|
mem_free_total = psutil.virtual_memory().available
|
||||||
mem_free_torch = mem_free_total
|
mem_free_torch = mem_free_total
|
||||||
elif hasattr(dev, 'type') and (dev.type == 'xpu'):
|
else:
|
||||||
|
if xpu_available:
|
||||||
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
|
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
|
||||||
mem_free_torch = mem_free_total
|
mem_free_torch = mem_free_total
|
||||||
else:
|
else:
|
||||||
@ -274,12 +281,9 @@ def mps_mode():
|
|||||||
global vram_state
|
global vram_state
|
||||||
return vram_state == MPS
|
return vram_state == MPS
|
||||||
|
|
||||||
def xpu_mode():
|
|
||||||
global vram_state
|
|
||||||
return vram_state == XPU
|
|
||||||
|
|
||||||
def should_use_fp16():
|
def should_use_fp16():
|
||||||
if cpu_mode() or mps_mode() or xpu_mode():
|
global xpu_available
|
||||||
|
if cpu_mode() or mps_mode() or xpu_available:
|
||||||
return False #TODO ?
|
return False #TODO ?
|
||||||
|
|
||||||
if torch.cuda.is_bf16_supported():
|
if torch.cuda.is_bf16_supported():
|
||||||
|
Loading…
Reference in New Issue
Block a user