From a9f04edc5887095f312bc16d9a6617e08c764678 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 9 Aug 2024 03:21:10 -0400 Subject: [PATCH] Implement text encoder part of HunyuanDiT loras. --- comfy/lora.py | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/comfy/lora.py b/comfy/lora.py index 0a38021c..3b8b6c16 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -1,3 +1,21 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + import comfy.utils import logging @@ -218,11 +236,17 @@ def model_lora_keys_clip(model, key_map={}): lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config key_map[lora_key] = k - for k in sdk: #OneTrainer SD3 lora - if k.startswith("t5xxl.transformer.") and k.endswith(".weight"): - l_key = k[len("t5xxl.transformer."):-len(".weight")] - lora_key = "lora_te3_{}".format(l_key.replace(".", "_")) - key_map[lora_key] = k + for k in sdk: + if k.endswith(".weight"): + if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora + l_key = k[len("t5xxl.transformer."):-len(".weight")] + lora_key = "lora_te3_{}".format(l_key.replace(".", "_")) + key_map[lora_key] = k + elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora + l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")] + lora_key = "lora_te1_{}".format(l_key.replace(".", "_")) + key_map[lora_key] = k + k = "clip_g.transformer.text_projection.weight" if k in sdk: