diff --git a/comfy/sd.py b/comfy/sd.py index a74d1e82..fe60205d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -400,7 +400,7 @@ class ControlNet: out.append(self.control_model) return out -def load_controlnet(ckpt_path): +def load_controlnet(ckpt_path, model=None): controlnet_data = load_torch_file(ckpt_path) pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' pth = False @@ -437,6 +437,21 @@ def load_controlnet(ckpt_path): use_fp16=use_fp16) if pth: + if 'difference' in controlnet_data: + if model is not None: + m = model.patch_model() + model_sd = m.state_dict() + for x in controlnet_data: + c_m = "control_model." + if x.startswith(c_m): + sd_key = "model.diffusion_model.{}".format(x[len(c_m):]) + if sd_key in model_sd: + cd = controlnet_data[x] + cd += model_sd[sd_key].type(cd.dtype).to(cd.device) + model.unpatch_model() + else: + print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") + class WeightsLoader(torch.nn.Module): pass w = WeightsLoader() diff --git a/nodes.py b/nodes.py index a41349cb..88fe664a 100644 --- a/nodes.py +++ b/nodes.py @@ -232,6 +232,24 @@ class ControlNetLoader: controlnet = comfy.sd.load_controlnet(controlnet_path) return (controlnet,) +class DiffControlNetLoader: + models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") + controlnet_dir = os.path.join(models_dir, "controlnet") + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}} + + RETURN_TYPES = ("CONTROL_NET",) + FUNCTION = "load_controlnet" + + CATEGORY = "loaders" + + def load_controlnet(self, model, control_net_name): + controlnet_path = os.path.join(self.controlnet_dir, control_net_name) + controlnet = comfy.sd.load_controlnet(controlnet_path, model) + return (controlnet,) + class ControlNetApply: @classmethod @@ -770,6 +788,7 @@ NODE_CLASS_MAPPINGS = { "CLIPLoader": CLIPLoader, "ControlNetApply": ControlNetApply, "ControlNetLoader": ControlNetLoader, + "DiffControlNetLoader": DiffControlNetLoader, } CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")