diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index bce4b3dd..ebcbd4be 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -27,6 +27,44 @@ class ModelMergeSimple: m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) +class ModelSubtract: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "_for_testing/model_merging" + + def merge(self, model1, model2, multiplier): + m = model1.clone() + kp = model2.get_key_patches("diffusion_model.") + for k in kp: + m.add_patches({k: kp[k]}, - multiplier, multiplier) + return (m, ) + +class ModelAdd: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "_for_testing/model_merging" + + def merge(self, model1, model2): + m = model1.clone() + kp = model2.get_key_patches("diffusion_model.") + for k in kp: + m.add_patches({k: kp[k]}, 1.0, 1.0) + return (m, ) + + class CLIPMergeSimple: @classmethod def INPUT_TYPES(s): @@ -144,6 +182,8 @@ class CheckpointSave: NODE_CLASS_MAPPINGS = { "ModelMergeSimple": ModelMergeSimple, "ModelMergeBlocks": ModelMergeBlocks, + "ModelMergeSubtract": ModelSubtract, + "ModelMergeAdd": ModelAdd, "CheckpointSave": CheckpointSave, "CLIPMergeSimple": CLIPMergeSimple, }