Fix some issues with TE being in lowvram mode.

This commit is contained in:
comfyanonymous 2024-08-12 23:42:21 -04:00
parent a562c17e8a
commit 74e124f4d7
3 changed files with 5 additions and 5 deletions

View File

@ -434,7 +434,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
soft_empty_cache() soft_empty_cache()
return unloaded_models return unloaded_models
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None): def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
global vram_state global vram_state
inference_memory = minimum_inference_memory() inference_memory = minimum_inference_memory()
@ -513,7 +513,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
else: else:
vram_set_state = vram_state vram_set_state = vram_state
lowvram_model_memory = 0 lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
model_size = loaded_model.model_memory_required(torch_dev) model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev) current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory())) lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))

View File

@ -411,7 +411,7 @@ class ModelPatcher:
logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
if lowvram_counter > 0: if lowvram_counter > 0:
logging.info("loaded partially {} {}".format(lowvram_model_memory / (1024 * 1024), patch_counter)) logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True self.model.model_lowvram = True
else: else:
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024))) logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))

View File

@ -86,7 +86,7 @@ class CLIP:
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
if params['device'] == load_device: if params['device'] == load_device:
model_management.load_model_gpu(self.patcher) model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None self.layer_idx = None
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device'])) logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
@ -585,7 +585,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
if inital_load_device != torch.device("cpu"): if inital_load_device != torch.device("cpu"):
logging.info("loaded straight to GPU") logging.info("loaded straight to GPU")
model_management.load_model_gpu(model_patcher) model_management.load_models_gpu([model_patcher], force_full_load=True)
return (model_patcher, clip, vae, clipvision) return (model_patcher, clip, vae, clipvision)