Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-01-25 15:55:18 +00:00)
Fix some lora loading slowdowns.
commit 517f4a94e4
parent 52a471c5c7
@@ -355,13 +355,14 @@ class ModelPatcher:
         return self.model
 
-    def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
+    def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
         for n, m in self.model.named_modules():
             lowvram_weight = False
-            if hasattr(m, "comfy_cast_weights"):
+
+            if not full_load and hasattr(m, "comfy_cast_weights"):
                 module_mem = comfy.model_management.module_size(m)
                 if mem_counter + module_mem >= lowvram_model_memory:
                     lowvram_weight = True
                     lowvram_counter += 1
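In plain terms, the new `full_load` flag short-circuits the per-module lowvram bookkeeping: when it is set, the `comfy_cast_weights` branch is never entered, so no module gets flagged for the slower on-the-fly casting path. A minimal standalone sketch of that branch logic, assuming a simplified `(name, size)` module list (`plan_module_loads` and the returned plan dict are illustrative stand-ins, not ComfyUI API):

# Hypothetical sketch of the branch above; sizes are in bytes.
def plan_module_loads(modules, lowvram_model_memory, full_load=False):
    """modules: iterable of (name, module_mem) pairs."""
    mem_counter = 0
    plan = {}
    for name, module_mem in modules:
        lowvram_weight = False
        # full_load=True bypasses the memory check entirely, so every
        # module ends up on the regular (fast) load path.
        if not full_load and mem_counter + module_mem >= lowvram_model_memory:
            lowvram_weight = True
        plan[name] = "lowvram" if lowvram_weight else "regular"
        if not lowvram_weight:
            mem_counter += module_mem
    return plan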
@@ -401,8 +402,11 @@ class ModelPatcher:
                         if weight.device == device_to:
                             continue
 
-                        self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM
-                        self.patch_weight_to_device(bias_key)
+                        weight_to = None
+                        if full_load:#TODO
+                            weight_to = device_to
+                        self.patch_weight_to_device(weight_key, device_to=weight_to) #TODO: speed this up without OOM
+                        self.patch_weight_to_device(bias_key, device_to=weight_to)
                         m.to(device_to)
                         logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
 
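The point of threading `weight_to` through is to avoid a double move in the full-load case: previously the patched weight was materialized on its current (typically offload) device and the following `m.to(device_to)` copied everything again. Passing `device_to=weight_to` lets the patched tensor land on the destination device directly. A hedged, self-contained sketch of that idea (`apply_patches` and the `(alpha, delta)` patch format are illustrative assumptions, not ComfyUI's actual patch representation):

import torch

def apply_patches(weight, patches, device_to=None):
    # Illustrative only: the real logic lives in ModelPatcher.patch_weight_to_device.
    temp = weight.to(torch.float32, copy=True)
    for alpha, delta in patches:  # assumed (strength, diff-tensor) pairs
        temp += alpha * delta.to(temp.device, dtype=torch.float32)
    out = temp.to(weight.dtype)
    if device_to is not None:
        # Move once, here, so the later module-wide .to(device_to)
        # finds this tensor already in place.
        out = out.to(device_to)
    return out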
@@ -665,12 +669,13 @@ class ModelPatcher:
         return memory_freed
 
     def partially_load(self, device_to, extra_memory=0):
+        full_load = False
         if self.model.model_lowvram == False:
             return 0
         if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
-            pass #TODO: Full load
+            full_load = True
         current_used = self.model.model_loaded_weight_memory
-        self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory)
+        self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
         return self.model.model_loaded_weight_memory - current_used
 
     def current_loaded_device(self):
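So `partially_load` now resolves the old `#TODO`: if the memory already resident plus the requested extra headroom exceeds the total model size, the whole model can be loaded at once via the fast path instead of being walked module by module. A small sketch of just that decision, with hypothetical names (all sizes in bytes):

def plan_partial_load(loaded_bytes, extra_bytes, total_bytes):
    """Mirrors the decision in partially_load; names are illustrative."""
    full_load = loaded_bytes + extra_bytes > total_bytes
    budget = loaded_bytes + extra_bytes  # becomes lowvram_model_memory
    return budget, full_load

# Example: 4 GiB resident + 3 GiB headroom against a 6 GiB model ->
# the 7 GiB budget exceeds the model size, so everything loads in one go.
assert plan_partial_load(4 << 30, 3 << 30, 6 << 30) == (7 << 30, True)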