diff --git a/comfy/sd.py b/comfy/sd.py index 6d1e8bb9..b8f82966 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -129,12 +129,17 @@ def load_lora(path, to_load): A_name = "{}.lora_up.weight".format(x) B_name = "{}.lora_down.weight".format(x) alpha_name = "{}.alpha".format(x) + mid_name = "{}.lora_mid.weight".format(x) if A_name in lora.keys(): alpha = None if alpha_name in lora.keys(): alpha = lora[alpha_name].item() loaded_keys.add(alpha_name) - patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha) + mid = None + if mid_name in lora.keys(): + mid = lora[mid_name] + loaded_keys.add(mid_name) + patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) loaded_keys.add(A_name) loaded_keys.add(B_name) for x in lora.keys(): @@ -279,6 +284,10 @@ class ModelPatcher: mat2 = v[1] if v[2] is not None: alpha *= v[2] / mat2.shape[0] + if v[3] is not None: + #locon mid weights, hopefully the math is fine because I didn't properly test it + final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]] + mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1) weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) return self.model def unpatch_model(self):