Add function to get the list of currently loaded models.

This commit is contained in:
comfyanonymous 2024-06-05 19:14:56 -04:00
parent b1fd26fe9e
commit 104fcea0c8

View File

@ -276,6 +276,7 @@ class LoadedModel:
self.device = model.load_device
self.weights_loaded = False
self.real_model = None
self.currently_used = True
def model_memory(self):
return self.model.model_size()
@ -365,6 +366,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
if shift_model.device == device:
if shift_model not in keep_loaded:
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
for x in sorted(can_unload):
i = x[-1]
@ -410,6 +412,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
loaded = None
else:
loaded.currently_used = True
models_already_loaded.append(loaded)
if loaded is None:
@ -466,6 +469,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
def load_model_gpu(model):
return load_models_gpu([model])
def loaded_models(only_currently_used=False):
output = []
for m in current_loaded_models:
if only_currently_used:
if not m.currently_used:
continue
output.append(m.model)
return output
def cleanup_models(keep_clone_weights_loaded=False):
to_delete = []
for i in range(len(current_loaded_models)):