diff --git a/comfy/sd.py b/comfy/sd.py index 92dbb931..3eb50cc9 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -111,6 +111,8 @@ def load_lora(path, to_load): loaded_keys.add(A_name) loaded_keys.add(B_name) + + ######## loha hada_w1_a_name = "{}.hada_w1_a".format(x) hada_w1_b_name = "{}.hada_w1_b".format(x) hada_w2_a_name = "{}.hada_w2_a".format(x) @@ -132,6 +134,54 @@ def load_lora(path, to_load): loaded_keys.add(hada_w2_a_name) loaded_keys.add(hada_w2_b_name) + + ######## lokr + lokr_w1_name = "{}.lokr_w1".format(x) + lokr_w2_name = "{}.lokr_w2".format(x) + lokr_w1_a_name = "{}.lokr_w1_a".format(x) + lokr_w1_b_name = "{}.lokr_w1_b".format(x) + lokr_t2_name = "{}.lokr_t2".format(x) + lokr_w2_a_name = "{}.lokr_w2_a".format(x) + lokr_w2_b_name = "{}.lokr_w2_b".format(x) + + lokr_w1 = None + if lokr_w1_name in lora.keys(): + lokr_w1 = lora[lokr_w1_name] + loaded_keys.add(lokr_w1_name) + + lokr_w2 = None + if lokr_w2_name in lora.keys(): + lokr_w2 = lora[lokr_w2_name] + loaded_keys.add(lokr_w2_name) + + lokr_w1_a = None + if lokr_w1_a_name in lora.keys(): + lokr_w1_a = lora[lokr_w1_a_name] + loaded_keys.add(lokr_w1_a_name) + + lokr_w1_b = None + if lokr_w1_b_name in lora.keys(): + lokr_w1_b = lora[lokr_w1_b_name] + loaded_keys.add(lokr_w1_b_name) + + lokr_w2_a = None + if lokr_w2_a_name in lora.keys(): + lokr_w2_a = lora[lokr_w2_a_name] + loaded_keys.add(lokr_w2_a_name) + + lokr_w2_b = None + if lokr_w2_b_name in lora.keys(): + lokr_w2_b = lora[lokr_w2_b_name] + loaded_keys.add(lokr_w2_b_name) + + lokr_t2 = None + if lokr_t2_name in lora.keys(): + lokr_t2 = lora[lokr_t2_name] + loaded_keys.add(lokr_t2_name) + + if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): + patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) + for x in lora.keys(): if x not in loaded_keys: print("lora key not loaded", x) @@ -315,6 +365,33 @@ class ModelPatcher: final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]] mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1) weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) + elif len(v) == 8: #lokr + w1 = v[0] + w2 = v[1] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + dim = None + + if w1 is None: + dim = w1_b.shape[0] + w1 = torch.mm(w1_a.float(), w1_b.float()) + + if w2 is None: + dim = w2_b.shape[0] + if t2 is None: + w2 = torch.mm(w2_a.float(), w2_b.float()) + else: + w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float()) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + if v[2] is not None and dim is not None: + alpha *= v[2] / dim + + weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device) else: #loha w1a = v[0] w1b = v[1]