Add a get_injections function to ModelPatcher

This commit is contained in:
Jedrzej Kosinski 2025-01-06 20:34:30 -06:00
parent 1b38f5bf57
commit 58bf8815c8
2 changed files with 5 additions and 2 deletions

View File

@ -842,6 +842,9 @@ class ModelPatcher:
if key in self.injections:
self.injections.pop(key)
def get_injections(self, key: str):
return self.injections.get(key, None)
def set_additional_models(self, key: str, models: list['ModelPatcher']):
self.additional_models[key] = models

View File

@ -865,7 +865,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
if len(casts) == 0:
return
# Try to call .to on patches
# try to call .to on patches
if "patches" in to_load_options:
patches = to_load_options["patches"]
for name in patches:
@ -882,7 +882,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
if hasattr(patch_list[k], "to"):
for cast in casts:
patch_list[k] = patch_list[k].to(cast)
# Try to call .to on any wrappers/callbacks
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options: