mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Support AuraFlow Lora and loading model weights in diffusers format.
You can load model weights in diffusers format using the UNETLoader node.
This commit is contained in:
parent
ce2473bb01
commit
a3dffc447a
@ -274,4 +274,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
|
||||
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
|
||||
key_map[key_lora] = to
|
||||
|
||||
return key_map
|
||||
|
@ -109,6 +109,10 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
unet_config = {}
|
||||
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
|
||||
unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
|
||||
double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
|
||||
single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
|
||||
unet_config["n_double_layers"] = double_layers
|
||||
unet_config["n_layers"] = double_layers + single_layers
|
||||
return unet_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
@ -450,37 +454,45 @@ def model_config_from_diffusers_unet(state_dict):
|
||||
return None
|
||||
|
||||
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
if num_blocks > 0:
|
||||
out_sd = {}
|
||||
|
||||
if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||
out_sd = {}
|
||||
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
||||
for k in sd_map:
|
||||
weight = state_dict.get(k, None)
|
||||
if weight is not None:
|
||||
t = sd_map[k]
|
||||
elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
|
||||
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
|
||||
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
|
||||
else:
|
||||
return None
|
||||
|
||||
if not isinstance(t, str):
|
||||
if len(t) > 2:
|
||||
fun = t[2]
|
||||
else:
|
||||
fun = lambda a: a
|
||||
offset = t[1]
|
||||
if offset is not None:
|
||||
old_weight = out_sd.get(t[0], None)
|
||||
if old_weight is None:
|
||||
old_weight = torch.empty_like(weight)
|
||||
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
|
||||
for k in sd_map:
|
||||
weight = state_dict.get(k, None)
|
||||
if weight is not None:
|
||||
t = sd_map[k]
|
||||
|
||||
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
||||
else:
|
||||
old_weight = weight
|
||||
w = weight
|
||||
w[:] = fun(weight)
|
||||
t = t[0]
|
||||
out_sd[t] = old_weight
|
||||
if not isinstance(t, str):
|
||||
if len(t) > 2:
|
||||
fun = t[2]
|
||||
else:
|
||||
out_sd[t] = weight
|
||||
state_dict.pop(k)
|
||||
fun = lambda a: a
|
||||
offset = t[1]
|
||||
if offset is not None:
|
||||
old_weight = out_sd.get(t[0], None)
|
||||
if old_weight is None:
|
||||
old_weight = torch.empty_like(weight)
|
||||
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
|
||||
|
||||
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
||||
else:
|
||||
old_weight = weight
|
||||
w = weight
|
||||
w[:] = fun(weight)
|
||||
t = t[0]
|
||||
out_sd[t] = old_weight
|
||||
else:
|
||||
out_sd[t] = weight
|
||||
state_dict.pop(k)
|
||||
|
||||
return out_sd
|
||||
|
33
comfy/sd.py
33
comfy/sd.py
@ -562,26 +562,25 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
|
||||
|
||||
if model_config is not None:
|
||||
new_sd = sd
|
||||
elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
|
||||
else:
|
||||
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
|
||||
if new_sd is None:
|
||||
return None
|
||||
model_config = model_detection.model_config_from_unet(new_sd, "")
|
||||
if model_config is None:
|
||||
return None
|
||||
else: #diffusers
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd)
|
||||
if model_config is None:
|
||||
return None
|
||||
if new_sd is not None: #diffusers mmdit
|
||||
model_config = model_detection.model_config_from_unet(new_sd, "")
|
||||
if model_config is None:
|
||||
return None
|
||||
else: #diffusers unet
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd)
|
||||
if model_config is None:
|
||||
return None
|
||||
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
|
||||
|
||||
new_sd = {}
|
||||
for k in diffusers_keys:
|
||||
if k in sd:
|
||||
new_sd[diffusers_keys[k]] = sd.pop(k)
|
||||
else:
|
||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||
new_sd = {}
|
||||
for k in diffusers_keys:
|
||||
if k in sd:
|
||||
new_sd[diffusers_keys[k]] = sd.pop(k)
|
||||
else:
|
||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||
|
@ -332,6 +332,76 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
def auraflow_to_diffusers(mmdit_config, output_prefix=""):
|
||||
n_double_layers = mmdit_config.get("n_double_layers", 0)
|
||||
n_layers = mmdit_config.get("n_layers", 0)
|
||||
|
||||
key_map = {}
|
||||
for i in range(n_layers):
|
||||
if i < n_double_layers:
|
||||
index = i
|
||||
prefix_from = "joint_transformer_blocks"
|
||||
prefix_to = "{}double_layers".format(output_prefix)
|
||||
block_map = {
|
||||
"attn.to_q.weight": "attn.w2q.weight",
|
||||
"attn.to_k.weight": "attn.w2k.weight",
|
||||
"attn.to_v.weight": "attn.w2v.weight",
|
||||
"attn.to_out.0.weight": "attn.w2o.weight",
|
||||
"attn.add_q_proj.weight": "attn.w1q.weight",
|
||||
"attn.add_k_proj.weight": "attn.w1k.weight",
|
||||
"attn.add_v_proj.weight": "attn.w1v.weight",
|
||||
"attn.to_add_out.weight": "attn.w1o.weight",
|
||||
"ff.linear_1.weight": "mlpX.c_fc1.weight",
|
||||
"ff.linear_2.weight": "mlpX.c_fc2.weight",
|
||||
"ff.out_projection.weight": "mlpX.c_proj.weight",
|
||||
"ff_context.linear_1.weight": "mlpC.c_fc1.weight",
|
||||
"ff_context.linear_2.weight": "mlpC.c_fc2.weight",
|
||||
"ff_context.out_projection.weight": "mlpC.c_proj.weight",
|
||||
"norm1.linear.weight": "modX.1.weight",
|
||||
"norm1_context.linear.weight": "modC.1.weight",
|
||||
}
|
||||
else:
|
||||
index = i - n_double_layers
|
||||
prefix_from = "single_transformer_blocks"
|
||||
prefix_to = "{}single_layers".format(output_prefix)
|
||||
|
||||
block_map = {
|
||||
"attn.to_q.weight": "attn.w1q.weight",
|
||||
"attn.to_k.weight": "attn.w1k.weight",
|
||||
"attn.to_v.weight": "attn.w1v.weight",
|
||||
"attn.to_out.0.weight": "attn.w1o.weight",
|
||||
"norm1.linear.weight": "modCX.1.weight",
|
||||
"ff.linear_1.weight": "mlp.c_fc1.weight",
|
||||
"ff.linear_2.weight": "mlp.c_fc2.weight",
|
||||
"ff.out_projection.weight": "mlp.c_proj.weight"
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k])
|
||||
|
||||
MAP_BASIC = {
|
||||
("positional_encoding", "pos_embed.pos_embed"),
|
||||
("register_tokens", "register_tokens"),
|
||||
("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
|
||||
("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
|
||||
("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
|
||||
("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
|
||||
("cond_seq_linear.weight", "context_embedder.weight"),
|
||||
("init_x_linear.weight", "pos_embed.proj.weight"),
|
||||
("init_x_linear.bias", "pos_embed.proj.bias"),
|
||||
("final_linear.weight", "proj_out.weight"),
|
||||
("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||
}
|
||||
|
||||
for k in MAP_BASIC:
|
||||
if len(k) > 2:
|
||||
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
|
||||
else:
|
||||
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
||||
|
||||
return key_map
|
||||
|
||||
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
||||
if tensor.shape[dim] > batch_size:
|
||||
return tensor.narrow(dim, 0, batch_size)
|
||||
|
Loading…
Reference in New Issue
Block a user