mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Add load_controlnet_state_dict function.
This commit is contained in:
parent
9f7e9f0547
commit
2d810b081e
@ -495,8 +495,8 @@ def convert_mistoline(sd):
|
|||||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None, model_options={}):
|
def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
controlnet_data = state_dict
|
||||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||||
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
|
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
@ -578,7 +578,7 @@ def load_controlnet(ckpt_path, model=None, model_options={}):
|
|||||||
else:
|
else:
|
||||||
net = load_t2i_adapter(controlnet_data, model_options=model_options)
|
net = load_t2i_adapter(controlnet_data, model_options=model_options)
|
||||||
if net is None:
|
if net is None:
|
||||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
logging.error("error could not detect control model type.")
|
||||||
return net
|
return net
|
||||||
|
|
||||||
if controlnet_config is None:
|
if controlnet_config is None:
|
||||||
@ -633,14 +633,21 @@ def load_controlnet(ckpt_path, model=None, model_options={}):
|
|||||||
if len(unexpected) > 0:
|
if len(unexpected) > 0:
|
||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||||
|
|
||||||
global_average_pooling = False
|
global_average_pooling = model_options.get("global_average_pooling", False)
|
||||||
filename = os.path.splitext(ckpt_path)[0]
|
|
||||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
|
||||||
global_average_pooling = True
|
|
||||||
|
|
||||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
def load_controlnet(ckpt_path, model=None, model_options={}):
|
||||||
|
if "global_average_pooling" not in model_options:
|
||||||
|
filename = os.path.splitext(ckpt_path)[0]
|
||||||
|
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||||
|
model_options["global_average_pooling"] = True
|
||||||
|
|
||||||
|
cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options)
|
||||||
|
if cnet is None:
|
||||||
|
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||||
|
return cnet
|
||||||
|
|
||||||
class T2IAdapter(ControlBase):
|
class T2IAdapter(ControlBase):
|
||||||
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
|
Loading…
Reference in New Issue
Block a user