Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-20 11:23:29 +00:00)
Fix issue with lowvram mode breaking model saving.
commit e1489ad257
parent 4f63ee99f1

@@ -285,7 +285,7 @@ class LoadedModel:
         else:
             return self.model_memory()
 
-    def model_load(self, lowvram_model_memory=0):
+    def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
         patch_model_to = self.device
 
         self.model.model_patches_to(self.device)
@@ -295,7 +295,7 @@ class LoadedModel:
 
         try:
             if lowvram_model_memory > 0 and load_weights:
-                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
+                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
             else:
                 self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
         except Exception as e:
@@ -379,7 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
     if mem_free_torch > mem_free_total * 0.25:
         soft_empty_cache()
 
-def load_models_gpu(models, memory_required=0):
+def load_models_gpu(models, memory_required=0, force_patch_weights=False):
     global vram_state
 
     inference_memory = minimum_inference_memory()
@@ -444,7 +444,7 @@ def load_models_gpu(models, memory_required=0):
         if vram_set_state == VRAMState.NO_VRAM:
             lowvram_model_memory = 64 * 1024 * 1024
 
-        cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
+        cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
         current_loaded_models.insert(0, loaded_model)
     return
 
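The four hunks above (presumably comfy/model_management.py) thread the new force_patch_weights flag from load_models_gpu through LoadedModel.model_load and on into ModelPatcher.patch_model_lowvram. A stripped-down, self-contained sketch of that call flow follows; the stub class and printouts are invented for illustration and omit all of the real device and memory handling:

# Illustration only: how force_patch_weights is forwarded down the loader chain.
# "StubPatcher" stands in for ComfyUI's ModelPatcher; the real code also handles
# devices, memory accounting, weight offloading, and error recovery.

class StubPatcher:
    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
        print("lowvram load, force_patch_weights =", force_patch_weights)

    def patch_model(self, device_to=None, patch_weights=True):
        print("normal load")

class LoadedModel:
    def __init__(self, model, device="cuda"):
        self.model = model
        self.device = device

    def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
        if lowvram_model_memory > 0:
            # lowvram path: pass the flag on so patches can still be baked in
            self.model.patch_model_lowvram(device_to=self.device,
                                           lowvram_model_memory=lowvram_model_memory,
                                           force_patch_weights=force_patch_weights)
        else:
            self.model.patch_model(device_to=self.device)

def load_models_gpu(models, memory_required=0, force_patch_weights=False):
    for model in models:
        lowvram_model_memory = 64 * 1024 * 1024   # pretend the LOW_VRAM/NO_VRAM path was selected
        LoadedModel(model).model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)

load_models_gpu([StubPatcher()], force_patch_weights=True)   # prints: lowvram load, force_patch_weights = True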
@@ -272,7 +272,7 @@ class ModelPatcher:
 
         return self.model
 
-    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
+    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
         self.patch_model(device_to, patch_weights=False)
 
         logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
@@ -296,9 +296,15 @@ class ModelPatcher:
 
             if lowvram_weight:
                 if weight_key in self.patches:
-                    m.weight_function = LowVramPatch(weight_key, self)
+                    if force_patch_weights:
+                        self.patch_weight_to_device(weight_key)
+                    else:
+                        m.weight_function = LowVramPatch(weight_key, self)
                 if bias_key in self.patches:
-                    m.bias_function = LowVramPatch(bias_key, self)
+                    if force_patch_weights:
+                        self.patch_weight_to_device(bias_key)
+                    else:
+                        m.bias_function = LowVramPatch(bias_key, self)
 
                 m.prev_comfy_cast_weights = m.comfy_cast_weights
                 m.comfy_cast_weights = True
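This ModelPatcher hunk (presumably comfy/model_patcher.py) is the heart of the fix. By default, lowvram mode defers weight patches by installing LowVramPatch wrappers (weight_function / bias_function) that apply patches only when a layer's weights are cast for use, so the stored tensors, and therefore any saved state dict, never contain the patched values. With force_patch_weights=True the patch is written into the weight immediately via patch_weight_to_device. The following self-contained sketch uses invented names (Layer, apply_patch, patch_layer) to show the difference between the lazy and the forced strategy:

# Minimal, self-contained illustration of "defer vs. force" weight patching.
# apply_patch, Layer, and the patch format are invented for this sketch; ComfyUI's
# real code works on torch modules and uses LowVramPatch / patch_weight_to_device.

def apply_patch(weight, patch):
    # pretend "patch" is just an additive delta, like a merged LoRA would be
    return [w + p for w, p in zip(weight, patch)]

class Layer:
    def __init__(self, weight):
        self.weight = weight
        self.weight_function = None   # deferred patch, applied only when the layer is used

    def effective_weight(self):
        # what the forward pass would see
        if self.weight_function is not None:
            return self.weight_function(self.weight)
        return self.weight

def patch_layer(layer, patch, force_patch_weights=False):
    if force_patch_weights:
        # eager: bake the patch into the stored weight so saving sees it
        layer.weight = apply_patch(layer.weight, patch)
    else:
        # lazy (lowvram default): keep the original weight, patch at use time
        layer.weight_function = lambda w: apply_patch(w, patch)

layer = Layer([1.0, 2.0])
patch_layer(layer, [0.5, 0.5])                          # lazy
print(layer.effective_weight())                         # [1.5, 2.5] at compute time
print(layer.weight)                                     # [1.0, 2.0] -> a save would miss the patch

layer2 = Layer([1.0, 2.0])
patch_layer(layer2, [0.5, 0.5], force_patch_weights=True)
print(layer2.weight)                                    # [1.5, 2.5] -> safe to save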
@@ -562,7 +562,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
         load_models.append(clip.load_model())
         clip_sd = clip.get_sd()
 
-    model_management.load_models_gpu(load_models)
+    model_management.load_models_gpu(load_models, force_patch_weights=True)
     clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
     sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
     for k in extra_keys:
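On the saving side, the only change needed is to ask the loader for fully patched weights before reading any state dict. A hedged usage sketch, assuming a ComfyUI environment; the helper below is invented, and ComfyUI's own save path goes through state_dict_for_saving rather than a raw state_dict():

# Usage sketch (assumes ComfyUI is importable); get_patched_state_dict is a
# made-up helper, not part of ComfyUI. The key point is force_patch_weights=True:
# it makes lowvram-loaded models bake their patches into the tensors before saving.
import comfy.model_management as model_management

def get_patched_state_dict(model_patcher):
    model_management.load_models_gpu([model_patcher], force_patch_weights=True)
    return model_patcher.model.state_dict()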
@@ -262,7 +262,7 @@ class CLIPSave:
             for x in extra_pnginfo:
                 metadata[x] = json.dumps(extra_pnginfo[x])
 
-        comfy.model_management.load_models_gpu([clip.load_model()])
+        comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
         clip_sd = clip.get_sd()
 
         for prefix in ["clip_l.", "clip_g.", ""]: