Mirror of https://github.com/comfyanonymous/ComfyUI.git
Fix lowvram issue with saving checkpoints.
The previous fix didn't cover the case where the model had just been loaded in lowvram mode.
parent 49c20cdc70
commit fa6dd7e5bb
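Why the reload is needed: in lowvram mode some weight patches are not baked into the weights but deferred to use time, so an already-loaded model cannot be reused when fully patched weights are required (as when saving a checkpoint). A minimal sketch of the decision this commit adds, mirroring the diff below with simplified class shapes:

    # Minimal sketch (simplified; names mirror the diff below).
    class LoadedModel:
        def __init__(self, model):
            self.model = model  # a ModelPatcher

        def should_reload_model(self, force_patch_weights=False):
            # Fully patched weights are required, but some patches were
            # deferred in lowvram mode -> unload and repatch from scratch.
            return force_patch_weights and self.model.lowvram_patch_counter > 0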
comfy/model_management.py
@@ -309,6 +309,11 @@ class LoadedModel:
         self.weights_loaded = True
         return self.real_model
 
+    def should_reload_model(self, force_patch_weights=False):
+        if force_patch_weights and self.model.lowvram_patch_counter > 0:
+            return True
+        return False
+
     def model_unload(self, unpatch_weights=True):
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
@@ -391,10 +396,22 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
     models_already_loaded = []
     for x in models:
         loaded_model = LoadedModel(x)
+        loaded = None
 
-        if loaded_model in current_loaded_models:
-            models_already_loaded.append(loaded_model)
-        else:
+        try:
+            loaded_model_index = current_loaded_models.index(loaded_model)
+        except:
+            loaded_model_index = None
+
+        if loaded_model_index is not None:
+            loaded = current_loaded_models[loaded_model_index]
+            if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
+                current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
+                loaded = None
+            else:
+                models_already_loaded.append(loaded)
+
+        if loaded is None:
             if hasattr(x, "model"):
                 logging.info(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
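A note on the lookup change above: list.index() raises ValueError when there is no match (the diff's bare except: swallows it), and unlike the old membership test it returns the position of the cached entry, so the existing LoadedModel, with its lowvram state, can be inspected, evicted, or reused. A hypothetical micro-example, assuming LoadedModel equality compares the wrapped model:

    try:
        i = current_loaded_models.index(loaded_model)  # ValueError if absent
        cached = current_loaded_models[i]              # cached instance, with its state
    except ValueError:
        cached = None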
comfy/model_patcher.py
@@ -58,6 +58,7 @@ class ModelPatcher:
 
         self.weight_inplace_update = weight_inplace_update
         self.model_lowvram = False
+        self.lowvram_patch_counter = 0
         self.patches_uuid = uuid.uuid4()
 
     def model_size(self):
@@ -284,6 +285,7 @@ class ModelPatcher:
                 return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
 
         mem_counter = 0
+        patch_counter = 0
         for n, m in self.model.named_modules():
             lowvram_weight = False
             if hasattr(m, "comfy_cast_weights"):
@@ -300,11 +302,13 @@ class ModelPatcher:
                         self.patch_weight_to_device(weight_key)
                     else:
                         m.weight_function = LowVramPatch(weight_key, self)
+                        patch_counter += 1
                 if bias_key in self.patches:
                     if force_patch_weights:
                         self.patch_weight_to_device(bias_key)
                     else:
                         m.bias_function = LowVramPatch(bias_key, self)
+                        patch_counter += 1
 
                 m.prev_comfy_cast_weights = m.comfy_cast_weights
                 m.comfy_cast_weights = True
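For context, the LowVramPatch objects assigned above are what make a lowvram load partially patched: the patch is computed when the weight is used rather than written into the module, and patch_counter counts how many patches were deferred this way. A rough sketch with assumed shapes; only the calculate_weight call is taken from the diff context above:

    class LowVramPatch:
        def __init__(self, key, model_patcher):
            self.key = key
            self.model_patcher = model_patcher

        def __call__(self, weight):
            # Apply the stored patches at use time instead of load time.
            return self.model_patcher.calculate_weight(
                self.model_patcher.patches[self.key], weight, self.key)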
@@ -317,6 +321,7 @@ class ModelPatcher:
                 logging.debug("lowvram: loaded module regularly {}".format(m))
 
         self.model_lowvram = True
+        self.lowvram_patch_counter = patch_counter
         return self.model
 
     def calculate_weight(self, patches, weight, key):
@@ -468,6 +473,7 @@ class ModelPatcher:
                 m.bias_function = None
 
         self.model_lowvram = False
+        self.lowvram_patch_counter = 0
 
         keys = list(self.backup.keys())
 
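Putting it together, a hedged usage sketch: a checkpoint-save path can request fully patched weights via the force_patch_weights parameter shown in the hunk header above; model_patcher below is a placeholder for whatever ModelPatcher is being serialized:

    import comfy.model_management

    # With this commit, a model previously loaded in lowvram mode
    # (lowvram_patch_counter > 0) is unloaded and reloaded fully patched
    # instead of being reused with deferred patches.
    comfy.model_management.load_models_gpu([model_patcher], force_patch_weights=True)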