mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Add support for locon mid weights.
This commit is contained in:
parent
451447bd9f
commit
cc309568e1
11
comfy/sd.py
11
comfy/sd.py
@ -129,12 +129,17 @@ def load_lora(path, to_load):
|
||||
A_name = "{}.lora_up.weight".format(x)
|
||||
B_name = "{}.lora_down.weight".format(x)
|
||||
alpha_name = "{}.alpha".format(x)
|
||||
mid_name = "{}.lora_mid.weight".format(x)
|
||||
if A_name in lora.keys():
|
||||
alpha = None
|
||||
if alpha_name in lora.keys():
|
||||
alpha = lora[alpha_name].item()
|
||||
loaded_keys.add(alpha_name)
|
||||
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha)
|
||||
mid = None
|
||||
if mid_name in lora.keys():
|
||||
mid = lora[mid_name]
|
||||
loaded_keys.add(mid_name)
|
||||
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
|
||||
loaded_keys.add(A_name)
|
||||
loaded_keys.add(B_name)
|
||||
for x in lora.keys():
|
||||
@ -279,6 +284,10 @@ class ModelPatcher:
|
||||
mat2 = v[1]
|
||||
if v[2] is not None:
|
||||
alpha *= v[2] / mat2.shape[0]
|
||||
if v[3] is not None:
|
||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
|
||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
|
||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
||||
return self.model
|
||||
def unpatch_model(self):
|
||||
|
Loading…
Reference in New Issue
Block a user