mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 11:23:29 +00:00
Support OneTrainer SD3 lora format.
This commit is contained in:
parent
4ef1479dcd
commit
2f360ae898
@ -218,12 +218,21 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
|
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
|
|
||||||
|
for k in sdk: #OneTrainer SD3 lora
|
||||||
|
if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
|
||||||
|
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||||
|
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
||||||
|
key_map[lora_key] = k
|
||||||
|
|
||||||
k = "clip_g.transformer.text_projection.weight"
|
k = "clip_g.transformer.text_projection.weight"
|
||||||
if k in sdk:
|
if k in sdk:
|
||||||
key_map["lora_prior_te_text_projection"] = k #cascade lora?
|
key_map["lora_prior_te_text_projection"] = k #cascade lora?
|
||||||
# key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
|
# key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
|
||||||
# key_map["lora_te_text_projection"] = k
|
key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora
|
||||||
|
|
||||||
|
k = "clip_l.transformer.text_projection.weight"
|
||||||
|
if k in sdk:
|
||||||
|
key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
@ -262,4 +271,7 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
|
key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
|
||||||
key_map[key_lora] = to
|
key_map[key_lora] = to
|
||||||
|
|
||||||
|
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
|
||||||
|
key_map[key_lora] = to
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
Loading…
Reference in New Issue
Block a user