Mirror of https://github.com/comfyanonymous/ComfyUI.git
Small refactor.

commit bceccca0e5 (parent 28a7205739)
@@ -129,7 +129,6 @@ def load_model_gpu(model):
     global current_loaded_model
     global vram_state
     global model_accelerated
-    global xpu_available
 
     if model is current_loaded_model:
         return
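Note: this commit replaces per-call-site xpu/cuda branching with a single get_torch_device() call. The helper's body is not part of this diff; the following is only a plausible sketch of what it might do, assuming the xpu_available check moved inside it (the cpu fallback is added here just so the sketch runs on any machine):

    import torch

    def get_torch_device():
        # Hypothetical sketch; the real helper lives elsewhere in
        # model_management.py and is not shown in this diff.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return torch.device("xpu")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")  # fallback so the sketch runs without a GPU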
@@ -148,17 +147,14 @@ def load_model_gpu(model):
         pass
     elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
         model_accelerated = False
-        if xpu_available:
-            real_model.to("xpu")
-        else:
-            real_model.cuda()
+        real_model.to(get_torch_device())
     else:
         if vram_state == VRAMState.NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         elif vram_state == VRAMState.LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
 
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
         model_accelerated = True
     return current_loaded_model
 
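For context, the accelerate calls in this hunk split a model between GPU and CPU when VRAM is scarce: weights mapped to "cpu" are offloaded and copied onto the main device on the fly during forward passes. A standalone toy example of the same pattern (hypothetical model; assumes accelerate is installed and a CUDA GPU is visible as device 0):

    import accelerate
    import torch

    # Toy stand-in for real_model; the max_memory caps mirror the NO_VRAM
    # branch above: at most 256MiB of weights on GPU 0, the rest in CPU RAM.
    model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)])

    device_map = accelerate.infer_auto_device_map(
        model, max_memory={0: "256MiB", "cpu": "16GiB"})
    accelerate.dispatch_model(model, device_map=device_map, main_device="cuda")
    # Layers that did not fit under the 256MiB cap stay offloaded on the CPU
    # and are moved to the main device when their forward pass runs.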
@@ -184,12 +180,8 @@ def load_controlnet_gpu(models):
 
 def load_if_low_vram(model):
     global vram_state
-    global xpu_available
     if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
-        if xpu_available:
-            return model.to("xpu")
-        else:
-            return model.cuda()
+        return model.to(get_torch_device())
     return model
 
 def unload_if_low_vram(model):
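Same consolidation on the caller side in load_if_low_vram(): one .to() call instead of a backend branch. A minimal illustration (hypothetical module, using a stub of the helper sketched earlier):

    import torch

    def get_torch_device():
        # Stub for this example only; see the fuller sketch above.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return torch.device("xpu")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    model = torch.nn.Linear(4, 4)
    # Before this commit, each call site branched on xpu_available:
    #     model.to("xpu") if xpu_available else model.cuda()
    # After, device selection is centralized in one helper:
    model = model.to(get_torch_device())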