Small refactor.

This commit is contained in:
comfyanonymous 2023-04-06 23:52:34 -04:00
parent 28a7205739
commit bceccca0e5

View File

@ -129,7 +129,6 @@ def load_model_gpu(model):
global current_loaded_model
global vram_state
global model_accelerated
global xpu_available
if model is current_loaded_model:
return
@ -148,17 +147,14 @@ def load_model_gpu(model):
pass
elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
model_accelerated = False
if xpu_available:
real_model.to("xpu")
else:
real_model.cuda()
real_model.to(get_torch_device())
else:
if vram_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
elif vram_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
model_accelerated = True
return current_loaded_model
@ -184,12 +180,8 @@ def load_controlnet_gpu(models):
def load_if_low_vram(model):
global vram_state
global xpu_available
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
if xpu_available:
return model.to("xpu")
else:
return model.cuda()
return model.to(get_torch_device())
return model
def unload_if_low_vram(model):