Merge T2IAdapterLoader and ControlNetLoader.

Workflows will be auto updated.
This commit is contained in:
comfyanonymous 2023-03-17 18:17:59 -04:00
parent e1a9e26968
commit 2e73367f45
6 changed files with 14 additions and 44 deletions

View File

@ -527,8 +527,10 @@ def load_controlnet(ckpt_path, model=None):
elif key in controlnet_data:
pass
else:
print("error checkpoint does not contain controlnet data", ckpt_path)
return None
net = load_t2i_adapter(controlnet_data)
if net is None:
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
return net
context_dim = controlnet_data[key].shape[1]
@ -682,15 +684,16 @@ class T2IAdapter:
out += self.previous_controlnet.get_control_models()
return out
def load_t2i_adapter(ckpt_path, model=None):
t2i_data = load_torch_file(ckpt_path)
def load_t2i_adapter(t2i_data):
keys = t2i_data.keys()
if "body.0.in_conv.weight" in keys:
cin = t2i_data['body.0.in_conv.weight'].shape[1]
model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
else:
elif 'conv_in.weight' in keys:
cin = t2i_data['conv_in.weight'].shape[1]
model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
else:
return None
model_ad.load_state_dict(t2i_data)
return T2IAdapter(model_ad, cin // 64)

View File

@ -2,7 +2,6 @@ import os
from comfy_extras.chainner_models import model_loading
from comfy.sd import load_torch_file
import model_management
from nodes import filter_files_extensions, recursive_search, supported_ckpt_extensions
import torch
import comfy.utils
import folder_paths

View File

@ -24,26 +24,6 @@ import model_management
import importlib
import folder_paths
supported_ckpt_extensions = ['.ckpt', '.pth']
supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
try:
import safetensors.torch
supported_ckpt_extensions += ['.safetensors']
supported_pt_extensions += ['.safetensors']
except:
print("Could not import safetensors, safetensors support disabled.")
def recursive_search(directory):
result = []
for root, subdir, file in os.walk(directory, followlinks=True):
for filepath in file:
#we os.path,join directory with a blank string to generate a path separator at the end.
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
return result
def filter_files_extensions(files, extensions):
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
def before_node_execution():
model_management.throw_exception_if_processing_interrupted()
@ -348,23 +328,6 @@ class ControlNetApply:
c.append(n)
return (c, )
class T2IAdapterLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
t2i_adapter_dir = os.path.join(models_dir, "t2i_adapter")
@classmethod
def INPUT_TYPES(s):
return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir), supported_pt_extensions), )}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_t2i_adapter"
CATEGORY = "loaders"
def load_t2i_adapter(self, t2i_adapter_name):
t2i_path = os.path.join(self.t2i_adapter_dir, t2i_adapter_name)
t2i_adapter = comfy.sd.load_t2i_adapter(t2i_path)
return (t2i_adapter,)
class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
@ -963,7 +926,6 @@ NODE_CLASS_MAPPINGS = {
"ControlNetApply": ControlNetApply,
"ControlNetLoader": ControlNetLoader,
"DiffControlNetLoader": DiffControlNetLoader,
"T2IAdapterLoader": T2IAdapterLoader,
"StyleModelLoader": StyleModelLoader,
"CLIPVisionLoader": CLIPVisionLoader,
"VAEDecodeTiled": VAEDecodeTiled,

View File

@ -614,6 +614,12 @@ class ComfyApp {
if (!graphData) {
graphData = defaultGraph;
}
// Patch T2IAdapterLoader to ControlNetLoader since they are the same node now
for (let n of graphData.nodes) {
if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader";
}
this.graph.configure(graphData);
for (const node of this.graph._nodes) {