I don't think controlnets were being handled correctly by MPS.

This commit is contained in:
comfyanonymous 2023-03-24 14:30:43 -04:00
parent 3c6ff8821c
commit 4adcea7228

View File

@ -62,8 +62,7 @@ if "--novram" in sys.argv:
set_vram_to = NO_VRAM set_vram_to = NO_VRAM
if "--highvram" in sys.argv: if "--highvram" in sys.argv:
vram_state = HIGH_VRAM vram_state = HIGH_VRAM
if torch.backends.mps.is_available():
vram_state = MPS
if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM: if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
try: try:
@ -78,6 +77,12 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
total_vram_available_mb = (total_vram - 1024) // 2 total_vram_available_mb = (total_vram - 1024) // 2
total_vram_available_mb = int(max(256, total_vram_available_mb)) total_vram_available_mb = int(max(256, total_vram_available_mb))
try:
if torch.backends.mps.is_available():
vram_state = MPS
except:
pass
if "--cpu" in sys.argv: if "--cpu" in sys.argv:
vram_state = CPU vram_state = CPU
@ -153,9 +158,6 @@ def load_controlnet_gpu(models):
if vram_state == CPU: if vram_state == CPU:
return return
if vram_state == MPS:
return
if vram_state == LOW_VRAM or vram_state == NO_VRAM: if vram_state == LOW_VRAM or vram_state == NO_VRAM:
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
return return
@ -164,9 +166,10 @@ def load_controlnet_gpu(models):
if m not in models: if m not in models:
m.cpu() m.cpu()
device = get_torch_device()
current_gpu_controlnets = [] current_gpu_controlnets = []
for m in models: for m in models:
current_gpu_controlnets.append(m.cuda()) current_gpu_controlnets.append(m.to(device))
def load_if_low_vram(model): def load_if_low_vram(model):