mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Small refactor.
This commit is contained in:
parent
28a7205739
commit
bceccca0e5
@ -129,7 +129,6 @@ def load_model_gpu(model):
|
||||
global current_loaded_model
|
||||
global vram_state
|
||||
global model_accelerated
|
||||
global xpu_available
|
||||
|
||||
if model is current_loaded_model:
|
||||
return
|
||||
@ -148,17 +147,14 @@ def load_model_gpu(model):
|
||||
pass
|
||||
elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
|
||||
model_accelerated = False
|
||||
if xpu_available:
|
||||
real_model.to("xpu")
|
||||
else:
|
||||
real_model.cuda()
|
||||
real_model.to(get_torch_device())
|
||||
else:
|
||||
if vram_state == VRAMState.NO_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
||||
elif vram_state == VRAMState.LOW_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
|
||||
|
||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
|
||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
|
||||
model_accelerated = True
|
||||
return current_loaded_model
|
||||
|
||||
@ -184,12 +180,8 @@ def load_controlnet_gpu(models):
|
||||
|
||||
def load_if_low_vram(model):
|
||||
global vram_state
|
||||
global xpu_available
|
||||
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||
if xpu_available:
|
||||
return model.to("xpu")
|
||||
else:
|
||||
return model.cuda()
|
||||
return model.to(get_torch_device())
|
||||
return model
|
||||
|
||||
def unload_if_low_vram(model):
|
||||
|
Loading…
Reference in New Issue
Block a user