diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 51a62e04..7d7977c1 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -842,6 +842,9 @@ class ModelPatcher: if key in self.injections: self.injections.pop(key) + def get_injections(self, key: str): + return self.injections.get(key, None) + def set_additional_models(self, key: str, models: list['ModelPatcher']): self.additional_models[key] = models diff --git a/comfy/samplers.py b/comfy/samplers.py index a725d518..5cc33a7d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -865,7 +865,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): if len(casts) == 0: return - # Try to call .to on patches + # try to call .to on patches if "patches" in to_load_options: patches = to_load_options["patches"] for name in patches: @@ -882,7 +882,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): if hasattr(patch_list[k], "to"): for cast in casts: patch_list[k] = patch_list[k].to(cast) - # Try to call .to on any wrappers/callbacks + # try to call .to on any wrappers/callbacks wrappers_and_callbacks = ["wrappers", "callbacks"] for wc_name in wrappers_and_callbacks: if wc_name in to_load_options: