Cleanups and fixes for model_management.py

Hopefully fix regression on MPS and CPU.
This commit is contained in:
comfyanonymous 2023-06-03 11:05:37 -04:00
parent 700491d81a
commit 0a5fefd621

View File

@ -4,16 +4,22 @@ from comfy.cli_args import args
import torch import torch
class VRAMState(Enum): class VRAMState(Enum):
CPU = 0 DISABLED = 0
NO_VRAM = 1 NO_VRAM = 1
LOW_VRAM = 2 LOW_VRAM = 2
NORMAL_VRAM = 3 NORMAL_VRAM = 3
HIGH_VRAM = 4 HIGH_VRAM = 4
MPS = 5 SHARED = 5
class CPUState(Enum):
GPU = 0
CPU = 1
MPS = 2
# Determine VRAM State # Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM vram_state = VRAMState.NORMAL_VRAM
set_vram_to = VRAMState.NORMAL_VRAM set_vram_to = VRAMState.NORMAL_VRAM
cpu_state = CPUState.GPU
total_vram = 0 total_vram = 0
@ -40,15 +46,25 @@ try:
except: except:
pass pass
try:
if torch.backends.mps.is_available():
cpu_state = CPUState.MPS
except:
pass
if args.cpu:
cpu_state = CPUState.CPU
def get_torch_device(): def get_torch_device():
global xpu_available global xpu_available
global directml_enabled global directml_enabled
global cpu_state
if directml_enabled: if directml_enabled:
global directml_device global directml_device
return directml_device return directml_device
if vram_state == VRAMState.MPS: if cpu_state == CPUState.MPS:
return torch.device("mps") return torch.device("mps")
if vram_state == VRAMState.CPU: if cpu_state == CPUState.CPU:
return torch.device("cpu") return torch.device("cpu")
else: else:
if xpu_available: if xpu_available:
@ -143,8 +159,6 @@ if args.force_fp32:
print("Forcing FP32, if this improves things please report it.") print("Forcing FP32, if this improves things please report it.")
FORCE_FP32 = True FORCE_FP32 = True
if lowvram_available: if lowvram_available:
try: try:
import accelerate import accelerate
@ -157,17 +171,15 @@ if lowvram_available:
lowvram_available = False lowvram_available = False
try: if cpu_state != CPUState.GPU:
if torch.backends.mps.is_available(): vram_state = VRAMState.DISABLED
vram_state = VRAMState.MPS
except:
pass
if args.cpu: if cpu_state == CPUState.MPS:
vram_state = VRAMState.CPU vram_state = VRAMState.SHARED
print(f"Set vram state to: {vram_state.name}") print(f"Set vram state to: {vram_state.name}")
def get_torch_device_name(device): def get_torch_device_name(device):
if hasattr(device, 'type'): if hasattr(device, 'type'):
if device.type == "cuda": if device.type == "cuda":
@ -241,13 +253,9 @@ def load_model_gpu(model):
current_loaded_model = model current_loaded_model = model
if vram_set_state == VRAMState.CPU: if vram_set_state == VRAMState.DISABLED:
pass pass
elif vram_set_state == VRAMState.MPS: elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
mps_device = torch.device("mps")
real_model.to(mps_device)
pass
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM:
model_accelerated = False model_accelerated = False
real_model.to(get_torch_device()) real_model.to(get_torch_device())
else: else:
@ -263,7 +271,7 @@ def load_model_gpu(model):
def load_controlnet_gpu(control_models): def load_controlnet_gpu(control_models):
global current_gpu_controlnets global current_gpu_controlnets
global vram_state global vram_state
if vram_state == VRAMState.CPU: if vram_state == VRAMState.DISABLED:
return return
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
@ -308,7 +316,8 @@ def get_autocast_device(dev):
def xformers_enabled(): def xformers_enabled():
global xpu_available global xpu_available
global directml_enabled global directml_enabled
if vram_state == VRAMState.CPU: global cpu_state
if cpu_state != CPUState.GPU:
return False return False
if xpu_available: if xpu_available:
return False return False
@ -380,12 +389,12 @@ def maximum_batch_area():
return int(max(area, 0)) return int(max(area, 0))
def cpu_mode(): def cpu_mode():
global vram_state global cpu_state
return vram_state == VRAMState.CPU return cpu_state == CPUState.CPU
def mps_mode(): def mps_mode():
global vram_state global cpu_state
return vram_state == VRAMState.MPS return cpu_state == CPUState.MPS
def should_use_fp16(): def should_use_fp16():
global xpu_available global xpu_available
@ -417,8 +426,8 @@ def should_use_fp16():
def soft_empty_cache(): def soft_empty_cache():
global xpu_available global xpu_available
global vram_state global cpu_state
if vram_state == VRAMState.MPS: if cpu_state == CPUState.MPS:
torch.mps.empty_cache() torch.mps.empty_cache()
elif xpu_available: elif xpu_available:
torch.xpu.empty_cache() torch.xpu.empty_cache()