mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
More generic unet prefix detection code.
This commit is contained in:
parent
14764aa2e2
commit
334ba48cea
@ -261,13 +261,22 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
|||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
def unet_prefix_from_state_dict(state_dict):
|
def unet_prefix_from_state_dict(state_dict):
|
||||||
if "model.model.postprocess_conv.weight" in state_dict: #audio models
|
candidates = ["model.diffusion_model.", #ldm/sgm models
|
||||||
unet_key_prefix = "model.model."
|
"model.model.", #audio models
|
||||||
elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow
|
]
|
||||||
unet_key_prefix = "model."
|
counts = {k: 0 for k in candidates}
|
||||||
|
for k in state_dict:
|
||||||
|
for c in candidates:
|
||||||
|
if k.startswith(c):
|
||||||
|
counts[c] += 1
|
||||||
|
break
|
||||||
|
|
||||||
|
top = max(counts, key=counts.get)
|
||||||
|
if counts[top] > 5:
|
||||||
|
return top
|
||||||
else:
|
else:
|
||||||
unet_key_prefix = "model.diffusion_model."
|
return "model." #aura flow and others
|
||||||
return unet_key_prefix
|
|
||||||
|
|
||||||
def convert_config(unet_config):
|
def convert_config(unet_config):
|
||||||
new_config = unet_config.copy()
|
new_config = unet_config.copy()
|
||||||
|
Loading…
Reference in New Issue
Block a user