Lower lora ram usage when in normal vram mode.

This commit is contained in:
comfyanonymous 2023-07-16 02:48:09 -04:00
parent 490771b7f4
commit 5f57362613
2 changed files with 19 additions and 13 deletions

View File

@ -233,10 +233,9 @@ def unload_model():
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
model_accelerated = False
current_loaded_model.unpatch_model()
current_loaded_model.model.to(current_loaded_model.offload_device)
current_loaded_model.model_patches_to(current_loaded_model.offload_device)
current_loaded_model.unpatch_model()
current_loaded_model = None
if vram_state != VRAMState.HIGH_VRAM:
soft_empty_cache()
@ -282,14 +281,6 @@ def load_model_gpu(model):
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
model_accelerated = False
real_model.to(torch_dev)
else:
if vram_set_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
elif vram_set_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
model_accelerated = True
try:
real_model = model.patch_model()
@ -298,6 +289,15 @@ def load_model_gpu(model):
unload_model()
raise e
if vram_set_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
model_accelerated = True
elif vram_set_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
model_accelerated = True
return current_loaded_model
def load_controlnet_gpu(control_models):

View File

@ -428,11 +428,17 @@ class ModelPatcher:
return weight
def unpatch_model(self):
model_sd = self.model_state_dict()
keys = list(self.backup.keys())
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value))
del prev
for k in keys:
model_sd[k][:] = self.backup[k]
del self.backup[k]
set_attr(self.model, k, self.backup[k])
self.backup = {}