Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-20 03:13:30 +00:00)
Commit 0a5fefd621 (parent 700491d81a)

Cleanups and fixes for model_management.py
Hopefully fix regression on MPS and CPU.
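At a glance: the commit splits device placement out of VRAMState into a new CPUState enum, so "which device are we running on" (GPU/CPU/MPS) no longer shares an enum with "how aggressively do we manage VRAM". A minimal, self-contained sketch of that split follows; the enum members and detection order mirror the diff below, while pick_device is a hypothetical stand-in for the real get_torch_device (which also handles DirectML and Intel XPU):

from enum import Enum
import torch

class CPUState(Enum):
    GPU = 0
    CPU = 1
    MPS = 2

# Module-level detection, as in the diff below: default to GPU and fall
# back to MPS when available (guarded, since older torch builds may lack
# torch.backends.mps).
cpu_state = CPUState.GPU
try:
    if torch.backends.mps.is_available():
        cpu_state = CPUState.MPS
except Exception:
    pass

def pick_device():
    # Hypothetical stand-in for get_torch_device(); the real function
    # also checks DirectML and XPU before falling back to CUDA.
    if cpu_state == CPUState.MPS:
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")
    return torch.device(torch.cuda.current_device())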
comfy/model_management.py

@@ -4,16 +4,22 @@ from comfy.cli_args import args
 import torch
 
 class VRAMState(Enum):
-    CPU = 0
+    DISABLED = 0
     NO_VRAM = 1
     LOW_VRAM = 2
     NORMAL_VRAM = 3
     HIGH_VRAM = 4
-    MPS = 5
+    SHARED = 5
 
+class CPUState(Enum):
+    GPU = 0
+    CPU = 1
+    MPS = 2
+
 # Determine VRAM State
 vram_state = VRAMState.NORMAL_VRAM
 set_vram_to = VRAMState.NORMAL_VRAM
+cpu_state = CPUState.GPU
 
 total_vram = 0
 
@@ -40,15 +46,25 @@ try:
 except:
     pass
 
+try:
+    if torch.backends.mps.is_available():
+        cpu_state = CPUState.MPS
+except:
+    pass
+
+if args.cpu:
+    cpu_state = CPUState.CPU
+
 def get_torch_device():
     global xpu_available
     global directml_enabled
+    global cpu_state
     if directml_enabled:
         global directml_device
         return directml_device
-    if vram_state == VRAMState.MPS:
+    if cpu_state == CPUState.MPS:
         return torch.device("mps")
-    if vram_state == VRAMState.CPU:
+    if cpu_state == CPUState.CPU:
         return torch.device("cpu")
     else:
         if xpu_available:
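Caller-side, the resolution order above means the returned device now follows cpu_state rather than vram_state. A hypothetical usage sketch, assuming the module is importable as comfy.model_management (the path the repository uses for this file):

import torch
import comfy.model_management as mm

device = mm.get_torch_device()
# Apple silicon with MPS available -> torch.device("mps")
# launched with --cpu              -> torch.device("cpu")
# otherwise                        -> the DirectML/XPU/CUDA device as before
x = torch.zeros(1, device=device)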
@@ -143,8 +159,6 @@ if args.force_fp32:
     print("Forcing FP32, if this improves things please report it.")
     FORCE_FP32 = True
 
-
-
 if lowvram_available:
     try:
         import accelerate
@@ -157,17 +171,15 @@ if lowvram_available:
         lowvram_available = False
 
 
-try:
-    if torch.backends.mps.is_available():
-        vram_state = VRAMState.MPS
-except:
-    pass
+if cpu_state != CPUState.GPU:
+    vram_state = VRAMState.DISABLED
 
-if args.cpu:
-    vram_state = VRAMState.CPU
+if cpu_state == CPUState.MPS:
+    vram_state = VRAMState.SHARED
 
 print(f"Set vram state to: {vram_state.name}")
 
 
 def get_torch_device_name(device):
     if hasattr(device, 'type'):
         if device.type == "cuda":
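The net effect of the block above, spelled out per platform (derived from the two if statements; not an exhaustive matrix of the command-line overrides):

# cpu_state == CPUState.GPU -> vram_state is left alone (NORMAL_VRAM unless
#                              a VRAM flag such as --lowvram changed it earlier)
# cpu_state == CPUState.CPU -> vram_state == VRAMState.DISABLED
# cpu_state == CPUState.MPS -> vram_state == VRAMState.SHARED
#                              (set to DISABLED first, then overridden)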
@@ -241,13 +253,9 @@ def load_model_gpu(model):
 
     current_loaded_model = model
 
-    if vram_set_state == VRAMState.CPU:
+    if vram_set_state == VRAMState.DISABLED:
         pass
-    elif vram_set_state == VRAMState.MPS:
-        mps_device = torch.device("mps")
-        real_model.to(mps_device)
-        pass
-    elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM:
+    elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
         model_accelerated = False
         real_model.to(get_torch_device())
     else:
@@ -263,7 +271,7 @@ def load_model_gpu(model):
 def load_controlnet_gpu(control_models):
     global current_gpu_controlnets
     global vram_state
-    if vram_state == VRAMState.CPU:
+    if vram_state == VRAMState.DISABLED:
         return
 
     if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
@@ -308,7 +316,8 @@ def get_autocast_device(dev):
 def xformers_enabled():
     global xpu_available
     global directml_enabled
-    if vram_state == VRAMState.CPU:
+    global cpu_state
+    if cpu_state != CPUState.GPU:
         return False
     if xpu_available:
         return False
@@ -380,12 +389,12 @@ def maximum_batch_area():
     return int(max(area, 0))
 
 def cpu_mode():
-    global vram_state
-    return vram_state == VRAMState.CPU
+    global cpu_state
+    return cpu_state == CPUState.CPU
 
 def mps_mode():
-    global vram_state
-    return vram_state == VRAMState.MPS
+    global cpu_state
+    return cpu_state == CPUState.MPS
 
 def should_use_fp16():
     global xpu_available
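For code elsewhere in the tree, the old vram_state comparisons map directly onto these helpers. A hypothetical caller-side sketch, again assuming the comfy.model_management import path:

import comfy.model_management as mm

if mm.mps_mode():
    # Apple MPS: unified memory, treated as SHARED VRAM above
    print("running on Apple MPS")
elif mm.cpu_mode():
    # --cpu fallback: no accelerator at all
    print("running on CPU")
else:
    print("running on a discrete GPU (CUDA/XPU/DirectML)")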
@@ -417,8 +426,8 @@ def should_use_fp16():
 
 def soft_empty_cache():
     global xpu_available
-    global vram_state
-    if vram_state == VRAMState.MPS:
+    global cpu_state
+    if cpu_state == CPUState.MPS:
         torch.mps.empty_cache()
     elif xpu_available:
         torch.xpu.empty_cache()