Make ComfyUI split batches a higher priority than weight offload.

This commit is contained in:
comfyanonymous 2024-08-01 16:39:59 -04:00
parent 1c61361fd2
commit d965474aaa
2 changed files with 10 additions and 4 deletions

View File

@ -379,11 +379,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
if mem_free_torch > mem_free_total * 0.25: if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache() soft_empty_cache()
def load_models_gpu(models, memory_required=0, force_patch_weights=False): def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
global vram_state global vram_state
inference_memory = minimum_inference_memory() inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required) extra_mem = max(inference_memory, memory_required)
if minimum_memory_required is None:
minimum_memory_required = extra_mem
else:
minimum_memory_required = max(inference_memory, minimum_memory_required)
models = set(models) models = set(models)
@ -446,8 +450,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev) model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev) current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - extra_mem))) lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required)))
if model_size <= (current_free_mem - inference_memory): #only switch to lowvram if really necessary if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
lowvram_model_memory = 0 lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM: if vram_set_state == VRAMState.NO_VRAM:

View File

@ -61,7 +61,9 @@ def prepare_sampling(model, noise_shape, conds):
device = model.load_device device = model.load_device
real_model = None real_model = None
models, inference_memory = get_additional_models(conds, model.model_dtype()) models, inference_memory = get_additional_models(conds, model.model_dtype())
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory) memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
real_model = model.model real_model = model.model
return real_model, conds, models return real_model, conds, models