Mirror of https://github.com/comfyanonymous/ComfyUI.git
Add support for locon mid weights.
parent 451447bd9f
commit cc309568e1
comfy/sd.py | 11 ++++++++++-
1 changed file with 10 additions and 1 deletion
@@ -129,12 +129,17 @@ def load_lora(path, to_load):
         A_name = "{}.lora_up.weight".format(x)
         B_name = "{}.lora_down.weight".format(x)
         alpha_name = "{}.alpha".format(x)
+        mid_name = "{}.lora_mid.weight".format(x)
         if A_name in lora.keys():
             alpha = None
             if alpha_name in lora.keys():
                 alpha = lora[alpha_name].item()
                 loaded_keys.add(alpha_name)
-            patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha)
+            mid = None
+            if mid_name in lora.keys():
+                mid = lora[mid_name]
+                loaded_keys.add(mid_name)
+            patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
             loaded_keys.add(A_name)
             loaded_keys.add(B_name)
     for x in lora.keys():
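For reference, a minimal sketch of the 4-tuple the loader now emits. The key prefix and tensor shapes below are hypothetical, not taken from a real checkpoint; plain LoRA entries simply end up with mid = None.

import torch

# Hypothetical LoCon entry for a 3x3 conv; prefix and shapes are illustrative only.
x = "lora_unet_input_blocks_1_0"
lora = {
    "{}.lora_up.weight".format(x): torch.randn(8, 2, 1, 1),    # out_ch x rank
    "{}.lora_down.weight".format(x): torch.randn(2, 4, 1, 1),  # rank x in_ch
    "{}.lora_mid.weight".format(x): torch.randn(2, 2, 3, 3),   # rank x rank x kh x kw
    "{}.alpha".format(x): torch.tensor(1.0),
}

# What load_lora now stores per patched weight: (up, down, alpha, mid).
patch = (lora["{}.lora_up.weight".format(x)],
         lora["{}.lora_down.weight".format(x)],
         lora["{}.alpha".format(x)].item(),
         lora["{}.lora_mid.weight".format(x)])  # mid is None when the key is absent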
@@ -279,6 +284,10 @@ class ModelPatcher:
                 mat2 = v[1]
                 if v[2] is not None:
                     alpha *= v[2] / mat2.shape[0]
+                if v[3] is not None:
+                    #locon mid weights, hopefully the math is fine because I didn't properly test it
+                    final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
+                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
                 weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
         return self.model
     def unpatch_model(self):
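The new branch folds the 1x1 lora_down projection into the lora_mid kernel, so the existing rank-factored matmul below it still applies unchanged. A quick numerical check of that reshape/transpose sequence, under assumed LoCon shapes (down: [rank, in_ch, 1, 1], mid: [rank, rank, k, k], up: [out_ch, rank, 1, 1]):

import torch
import torch.nn.functional as F

# Assumed shapes for a 3x3 conv patch; illustrative, not from a real checkpoint.
out_ch, in_ch, rank, k = 8, 4, 2, 3
mat1 = torch.randn(out_ch, rank, 1, 1)  # lora_up.weight   (v[0])
mat2 = torch.randn(rank, in_ch, 1, 1)   # lora_down.weight (v[1])
mid = torch.randn(rank, rank, k, k)     # lora_mid.weight  (v[3])

# The merge from the diff: fold the 1x1 down projection into the mid kernel,
# producing a single [rank, in_ch, k, k] kernel.
final_shape = [mat2.shape[1], mat2.shape[0], mid.shape[2], mid.shape[3]]
merged = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1),
                  mid.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)

# Full weight delta, as patch_model computes it (alpha factor omitted for clarity).
delta_w = torch.mm(mat1.flatten(start_dim=1),
                   merged.flatten(start_dim=1)).reshape(out_ch, in_ch, k, k)

# One conv with the merged delta should equal chaining the three factor convs.
x = torch.randn(1, in_ch, 16, 16)
a = F.conv2d(x, delta_w, padding=k // 2)
b = F.conv2d(F.conv2d(F.conv2d(x, mat2), mid, padding=k // 2), mat1)
print(torch.allclose(a, b, atol=1e-4))  # expect: True

For these shapes the merged kernel and the chained convolutions agree, which supports the "hopefully the math is fine" comment in the diff.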