mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add loha support.
This commit is contained in:
parent
cc127eeabd
commit
94a7c895f4
@ -16,7 +16,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
|||||||
- Works even if you don't have a GPU with: ```--cpu``` (slow)
|
- Works even if you don't have a GPU with: ```--cpu``` (slow)
|
||||||
- Can load both ckpt and safetensors models/checkpoints. Standalone VAEs and CLIP models.
|
- Can load both ckpt and safetensors models/checkpoints. Standalone VAEs and CLIP models.
|
||||||
- Embeddings/Textual inversion
|
- Embeddings/Textual inversion
|
||||||
- [Loras (regular and locon)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
||||||
- Loading full workflows (with seeds) from generated PNG files.
|
- Loading full workflows (with seeds) from generated PNG files.
|
||||||
- Saving/Loading workflows as Json files.
|
- Saving/Loading workflows as Json files.
|
||||||
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
|
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
|
||||||
|
52
comfy/sd.py
52
comfy/sd.py
@ -126,15 +126,17 @@ def load_lora(path, to_load):
|
|||||||
patch_dict = {}
|
patch_dict = {}
|
||||||
loaded_keys = set()
|
loaded_keys = set()
|
||||||
for x in to_load:
|
for x in to_load:
|
||||||
|
alpha_name = "{}.alpha".format(x)
|
||||||
|
alpha = None
|
||||||
|
if alpha_name in lora.keys():
|
||||||
|
alpha = lora[alpha_name].item()
|
||||||
|
loaded_keys.add(alpha_name)
|
||||||
|
|
||||||
A_name = "{}.lora_up.weight".format(x)
|
A_name = "{}.lora_up.weight".format(x)
|
||||||
B_name = "{}.lora_down.weight".format(x)
|
B_name = "{}.lora_down.weight".format(x)
|
||||||
alpha_name = "{}.alpha".format(x)
|
|
||||||
mid_name = "{}.lora_mid.weight".format(x)
|
mid_name = "{}.lora_mid.weight".format(x)
|
||||||
|
|
||||||
if A_name in lora.keys():
|
if A_name in lora.keys():
|
||||||
alpha = None
|
|
||||||
if alpha_name in lora.keys():
|
|
||||||
alpha = lora[alpha_name].item()
|
|
||||||
loaded_keys.add(alpha_name)
|
|
||||||
mid = None
|
mid = None
|
||||||
if mid_name in lora.keys():
|
if mid_name in lora.keys():
|
||||||
mid = lora[mid_name]
|
mid = lora[mid_name]
|
||||||
@ -142,6 +144,18 @@ def load_lora(path, to_load):
|
|||||||
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
|
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
|
||||||
loaded_keys.add(A_name)
|
loaded_keys.add(A_name)
|
||||||
loaded_keys.add(B_name)
|
loaded_keys.add(B_name)
|
||||||
|
|
||||||
|
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
||||||
|
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
||||||
|
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
||||||
|
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
||||||
|
if hada_w1_a_name in lora.keys():
|
||||||
|
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name])
|
||||||
|
loaded_keys.add(hada_w1_a_name)
|
||||||
|
loaded_keys.add(hada_w1_b_name)
|
||||||
|
loaded_keys.add(hada_w2_a_name)
|
||||||
|
loaded_keys.add(hada_w2_b_name)
|
||||||
|
|
||||||
for x in lora.keys():
|
for x in lora.keys():
|
||||||
if x not in loaded_keys:
|
if x not in loaded_keys:
|
||||||
print("lora key not loaded", x)
|
print("lora key not loaded", x)
|
||||||
@ -280,15 +294,25 @@ class ModelPatcher:
|
|||||||
self.backup[key] = weight.clone()
|
self.backup[key] = weight.clone()
|
||||||
|
|
||||||
alpha = p[0]
|
alpha = p[0]
|
||||||
mat1 = v[0]
|
|
||||||
mat2 = v[1]
|
if len(v) == 4: #lora/locon
|
||||||
if v[2] is not None:
|
mat1 = v[0]
|
||||||
alpha *= v[2] / mat2.shape[0]
|
mat2 = v[1]
|
||||||
if v[3] is not None:
|
if v[2] is not None:
|
||||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
alpha *= v[2] / mat2.shape[0]
|
||||||
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
|
if v[3] is not None:
|
||||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
|
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
|
||||||
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
|
||||||
|
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
||||||
|
else: #loha
|
||||||
|
w1a = v[0]
|
||||||
|
w1b = v[1]
|
||||||
|
if v[2] is not None:
|
||||||
|
alpha *= v[2] / w1b.shape[0]
|
||||||
|
w2a = v[3]
|
||||||
|
w2b = v[4]
|
||||||
|
weight += (alpha * torch.mm(w1a.float(), w1b.float()) * torch.mm(w2a.float(), w2b.float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
||||||
return self.model
|
return self.model
|
||||||
def unpatch_model(self):
|
def unpatch_model(self):
|
||||||
model_sd = self.model.state_dict()
|
model_sd = self.model.state_dict()
|
||||||
|
Loading…
Reference in New Issue
Block a user