diff --git a/comfy/lora.py b/comfy/lora.py index 9e3e7a8b..03743177 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -218,12 +218,21 @@ def model_lora_keys_clip(model, key_map={}): lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config key_map[lora_key] = k + for k in sdk: #OneTrainer SD3 lora + if k.startswith("t5xxl.transformer.") and k.endswith(".weight"): + l_key = k[len("t5xxl.transformer."):-len(".weight")] + lora_key = "lora_te3_{}".format(l_key.replace(".", "_")) + key_map[lora_key] = k k = "clip_g.transformer.text_projection.weight" if k in sdk: key_map["lora_prior_te_text_projection"] = k #cascade lora? # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too - # key_map["lora_te_text_projection"] = k + key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora + + k = "clip_l.transformer.text_projection.weight" + if k in sdk: + key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning return key_map @@ -262,4 +271,7 @@ def model_lora_keys_unet(model, key_map={}): key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others? key_map[key_lora] = to + key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora + key_map[key_lora] = to + return key_map