Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-01-25 15:55:18 +00:00
Lower lora ram usage when in normal vram mode.
commit 5f57362613
parent 490771b7f4
comfy/model_management.py
@@ -233,10 +233,9 @@ def unload_model():
             accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
             model_accelerated = False

-
+        current_loaded_model.unpatch_model()
         current_loaded_model.model.to(current_loaded_model.offload_device)
         current_loaded_model.model_patches_to(current_loaded_model.offload_device)
-        current_loaded_model.unpatch_model()
         current_loaded_model = None
         if vram_state != VRAMState.HIGH_VRAM:
             soft_empty_cache()
@@ -282,14 +281,6 @@ def load_model_gpu(model):
     elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
         model_accelerated = False
         real_model.to(torch_dev)
-    else:
-        if vram_set_state == VRAMState.NO_VRAM:
-            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
-        elif vram_set_state == VRAMState.LOW_VRAM:
-            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
-
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
-        model_accelerated = True

     try:
         real_model = model.patch_model()
@@ -298,6 +289,15 @@ def load_model_gpu(model):
         unload_model()
         raise e

+    if vram_set_state == VRAMState.NO_VRAM:
+        device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
+        model_accelerated = True
+    elif vram_set_state == VRAMState.LOW_VRAM:
+        device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
+        model_accelerated = True
+
     return current_loaded_model


 def load_controlnet_gpu(control_models):
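For context: the two load_model_gpu hunks above move the accelerate dispatch from before model.patch_model() to after it, so LoRA patches are applied before the model is split across devices. Below is a minimal sketch of that dispatch pattern; the toy model is invented for illustration, only the memory budgets and the infer_auto_device_map / dispatch_model calls are taken from the diff, and actually running it requires a CUDA device.

import torch
import accelerate

# Stand-in model: 8 x Linear(4096, 4096) in fp32 is roughly 512MB, enough to
# overflow the 256MiB GPU budget below and force some layers onto the CPU.
toy_model = torch.nn.Sequential(*[torch.nn.Linear(4096, 4096) for _ in range(8)])

# Assign each submodule to a device without exceeding the per-device budgets.
device_map = accelerate.infer_auto_device_map(
    toy_model, max_memory={0: "256MiB", "cpu": "16GiB"})

# Install hooks that move each submodule's weights to its assigned device at
# call time, with GPU 0 as the main device, as the NO_VRAM/LOW_VRAM branches do.
toy_model = accelerate.dispatch_model(
    toy_model, device_map=device_map, main_device=torch.device("cuda:0"))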
comfy/sd.py (12 changed lines)
@@ -428,11 +428,17 @@ class ModelPatcher:
         return weight

     def unpatch_model(self):
-        model_sd = self.model_state_dict()
         keys = list(self.backup.keys())
+        def set_attr(obj, attr, value):
+            attrs = attr.split(".")
+            for name in attrs[:-1]:
+                obj = getattr(obj, name)
+            prev = getattr(obj, attrs[-1])
+            setattr(obj, attrs[-1], torch.nn.Parameter(value))
+            del prev
+
         for k in keys:
-            model_sd[k][:] = self.backup[k]
-            del self.backup[k]
+            set_attr(self.model, k, self.backup[k])

         self.backup = {}

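The set_attr helper added above is what lets the patched weight copies be freed: the old code copied backed-up values element-wise into the patched tensors (model_sd[k][:] = self.backup[k]), which keeps the patched tensor alive as the copy destination, while the new code walks the dotted parameter path and swaps the backed-up tensor in as the module's Parameter. A minimal sketch of the mechanism follows; the helper is taken verbatim from the diff, but the tiny module, the "0.weight" key, and the stand-in "patch" are invented for the example.

import torch

# Helper from the diff: resolve an "a.b.weight"-style key to its owning
# module, then replace the Parameter outright instead of copying into it.
def set_attr(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    setattr(obj, attrs[-1], torch.nn.Parameter(value))
    del prev  # drop the reference to the patched Parameter so it can be freed

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
backup = {"0.weight": model[0].weight.data.clone()}  # snapshot before patching

model[0].weight.data += 0.1                      # stand-in for a lora patch
set_attr(model, "0.weight", backup["0.weight"])  # restore by swapping, not copying
assert torch.equal(model[0].weight.data, backup["0.weight"])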