Add loha support.

This commit is contained in:
comfyanonymous 2023-03-23 03:40:12 -04:00
parent cc127eeabd
commit 94a7c895f4
2 changed files with 39 additions and 15 deletions

View File

@ -16,7 +16,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- Works even if you don't have a GPU with: ```--cpu``` (slow) - Works even if you don't have a GPU with: ```--cpu``` (slow)
- Can load both ckpt and safetensors models/checkpoints. Standalone VAEs and CLIP models. - Can load both ckpt and safetensors models/checkpoints. Standalone VAEs and CLIP models.
- Embeddings/Textual inversion - Embeddings/Textual inversion
- [Loras (regular and locon)](https://comfyanonymous.github.io/ComfyUI_examples/lora/) - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- Loading full workflows (with seeds) from generated PNG files. - Loading full workflows (with seeds) from generated PNG files.
- Saving/Loading workflows as Json files. - Saving/Loading workflows as Json files.
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones. - Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.

View File

@ -126,15 +126,17 @@ def load_lora(path, to_load):
patch_dict = {} patch_dict = {}
loaded_keys = set() loaded_keys = set()
for x in to_load: for x in to_load:
A_name = "{}.lora_up.weight".format(x)
B_name = "{}.lora_down.weight".format(x)
alpha_name = "{}.alpha".format(x) alpha_name = "{}.alpha".format(x)
mid_name = "{}.lora_mid.weight".format(x)
if A_name in lora.keys():
alpha = None alpha = None
if alpha_name in lora.keys(): if alpha_name in lora.keys():
alpha = lora[alpha_name].item() alpha = lora[alpha_name].item()
loaded_keys.add(alpha_name) loaded_keys.add(alpha_name)
A_name = "{}.lora_up.weight".format(x)
B_name = "{}.lora_down.weight".format(x)
mid_name = "{}.lora_mid.weight".format(x)
if A_name in lora.keys():
mid = None mid = None
if mid_name in lora.keys(): if mid_name in lora.keys():
mid = lora[mid_name] mid = lora[mid_name]
@ -142,6 +144,18 @@ def load_lora(path, to_load):
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
loaded_keys.add(A_name) loaded_keys.add(A_name)
loaded_keys.add(B_name) loaded_keys.add(B_name)
hada_w1_a_name = "{}.hada_w1_a".format(x)
hada_w1_b_name = "{}.hada_w1_b".format(x)
hada_w2_a_name = "{}.hada_w2_a".format(x)
hada_w2_b_name = "{}.hada_w2_b".format(x)
if hada_w1_a_name in lora.keys():
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name])
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
loaded_keys.add(hada_w2_b_name)
for x in lora.keys(): for x in lora.keys():
if x not in loaded_keys: if x not in loaded_keys:
print("lora key not loaded", x) print("lora key not loaded", x)
@ -280,6 +294,8 @@ class ModelPatcher:
self.backup[key] = weight.clone() self.backup[key] = weight.clone()
alpha = p[0] alpha = p[0]
if len(v) == 4: #lora/locon
mat1 = v[0] mat1 = v[0]
mat2 = v[1] mat2 = v[1]
if v[2] is not None: if v[2] is not None:
@ -289,6 +305,14 @@ class ModelPatcher:
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]] final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1) mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
else: #loha
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha *= v[2] / w1b.shape[0]
w2a = v[3]
w2b = v[4]
weight += (alpha * torch.mm(w1a.float(), w1b.float()) * torch.mm(w2a.float(), w2b.float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
return self.model return self.model
def unpatch_model(self): def unpatch_model(self):
model_sd = self.model.state_dict() model_sd = self.model.state_dict()