mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-10 18:05:16 +00:00
Merge T2IAdapterLoader and ControlNetLoader.
Workflows will be auto updated.
This commit is contained in:
parent
e1a9e26968
commit
2e73367f45
13
comfy/sd.py
13
comfy/sd.py
@ -527,8 +527,10 @@ def load_controlnet(ckpt_path, model=None):
|
||||
elif key in controlnet_data:
|
||||
pass
|
||||
else:
|
||||
print("error checkpoint does not contain controlnet data", ckpt_path)
|
||||
return None
|
||||
net = load_t2i_adapter(controlnet_data)
|
||||
if net is None:
|
||||
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
|
||||
return net
|
||||
|
||||
context_dim = controlnet_data[key].shape[1]
|
||||
|
||||
@ -682,15 +684,16 @@ class T2IAdapter:
|
||||
out += self.previous_controlnet.get_control_models()
|
||||
return out
|
||||
|
||||
def load_t2i_adapter(ckpt_path, model=None):
|
||||
t2i_data = load_torch_file(ckpt_path)
|
||||
def load_t2i_adapter(t2i_data):
|
||||
keys = t2i_data.keys()
|
||||
if "body.0.in_conv.weight" in keys:
|
||||
cin = t2i_data['body.0.in_conv.weight'].shape[1]
|
||||
model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
|
||||
else:
|
||||
elif 'conv_in.weight' in keys:
|
||||
cin = t2i_data['conv_in.weight'].shape[1]
|
||||
model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
|
||||
else:
|
||||
return None
|
||||
model_ad.load_state_dict(t2i_data)
|
||||
return T2IAdapter(model_ad, cin // 64)
|
||||
|
||||
|
@ -2,7 +2,6 @@ import os
|
||||
from comfy_extras.chainner_models import model_loading
|
||||
from comfy.sd import load_torch_file
|
||||
import model_management
|
||||
from nodes import filter_files_extensions, recursive_search, supported_ckpt_extensions
|
||||
import torch
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
|
38
nodes.py
38
nodes.py
@ -24,26 +24,6 @@ import model_management
|
||||
import importlib
|
||||
|
||||
import folder_paths
|
||||
supported_ckpt_extensions = ['.ckpt', '.pth']
|
||||
supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
|
||||
try:
|
||||
import safetensors.torch
|
||||
supported_ckpt_extensions += ['.safetensors']
|
||||
supported_pt_extensions += ['.safetensors']
|
||||
except:
|
||||
print("Could not import safetensors, safetensors support disabled.")
|
||||
|
||||
def recursive_search(directory):
|
||||
result = []
|
||||
for root, subdir, file in os.walk(directory, followlinks=True):
|
||||
for filepath in file:
|
||||
#we os.path,join directory with a blank string to generate a path separator at the end.
|
||||
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
|
||||
return result
|
||||
|
||||
def filter_files_extensions(files, extensions):
|
||||
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
|
||||
|
||||
|
||||
def before_node_execution():
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
@ -348,23 +328,6 @@ class ControlNetApply:
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
class T2IAdapterLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
t2i_adapter_dir = os.path.join(models_dir, "t2i_adapter")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir), supported_pt_extensions), )}}
|
||||
|
||||
RETURN_TYPES = ("CONTROL_NET",)
|
||||
FUNCTION = "load_t2i_adapter"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_t2i_adapter(self, t2i_adapter_name):
|
||||
t2i_path = os.path.join(self.t2i_adapter_dir, t2i_adapter_name)
|
||||
t2i_adapter = comfy.sd.load_t2i_adapter(t2i_path)
|
||||
return (t2i_adapter,)
|
||||
|
||||
class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -963,7 +926,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ControlNetApply": ControlNetApply,
|
||||
"ControlNetLoader": ControlNetLoader,
|
||||
"DiffControlNetLoader": DiffControlNetLoader,
|
||||
"T2IAdapterLoader": T2IAdapterLoader,
|
||||
"StyleModelLoader": StyleModelLoader,
|
||||
"CLIPVisionLoader": CLIPVisionLoader,
|
||||
"VAEDecodeTiled": VAEDecodeTiled,
|
||||
|
@ -614,6 +614,12 @@ class ComfyApp {
|
||||
if (!graphData) {
|
||||
graphData = defaultGraph;
|
||||
}
|
||||
|
||||
// Patch T2IAdapterLoader to ControlNetLoader since they are the same node now
|
||||
for (let n of graphData.nodes) {
|
||||
if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader";
|
||||
}
|
||||
|
||||
this.graph.configure(graphData);
|
||||
|
||||
for (const node of this.graph._nodes) {
|
||||
|
Loading…
Reference in New Issue
Block a user