From 334ba48cea2961994e92c2fb25de9417b19897ed Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 Jul 2024 14:13:32 -0400 Subject: [PATCH] More generic unet prefix detection code. --- comfy/model_detection.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index c62e2b82..ae88eeb9 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -261,13 +261,22 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal return model_config def unet_prefix_from_state_dict(state_dict): - if "model.model.postprocess_conv.weight" in state_dict: #audio models - unet_key_prefix = "model.model." - elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow - unet_key_prefix = "model." + candidates = ["model.diffusion_model.", #ldm/sgm models + "model.model.", #audio models + ] + counts = {k: 0 for k in candidates} + for k in state_dict: + for c in candidates: + if k.startswith(c): + counts[c] += 1 + break + + top = max(counts, key=counts.get) + if counts[top] > 5: + return top else: - unet_key_prefix = "model.diffusion_model." - return unet_key_prefix + return "model." #aura flow and others + def convert_config(unet_config): new_config = unet_config.copy()