From 203942c8b29dfbf59a7976dcee29e8ab44a1b32d Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 8 Oct 2024 19:02:41 -0400
Subject: [PATCH] Fix flux doras with diffusers keys.

---
 comfy/lora.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/comfy/lora.py b/comfy/lora.py
index 2d0ed7bf..80057cdd 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -343,10 +343,10 @@ def model_lora_keys_unet(model, key_map={}):
 
     return key_map
 
-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
     dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
     lora_diff *= alpha
-    weight_calc = weight + lora_diff.type(weight.dtype)
+    weight_calc = weight + function(lora_diff).type(weight.dtype)
     weight_norm = (
         weight_calc.transpose(0, 1)
         .reshape(weight_calc.shape[1], -1)
@@ -453,7 +453,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             try:
                 lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
@@ -499,7 +499,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             try:
                 lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
@@ -536,7 +536,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             try:
                 lora_diff = (m1 * m2).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
@@ -577,7 +577,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
                 lora_diff += torch.mm(b1, b2).reshape(weight.shape)
 
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
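
Note on the change, with a sketch for context: `function` is the per-patch mapping hook
that `calculate_weight` receives with each patch (for diffusers-keyed flux LoRAs it maps a
diff into the layout of the target weight). Before this patch, the DoRA path applied the
hook to the already-decomposed full weight, `function(weight_decompose(...))`; the patch
instead passes `function` into `weight_decompose`, which applies it to `lora_diff` before
the merged weight and its norm are computed, mirroring the non-DoRA branch
`weight += function(...)`. The sketch below is a minimal standalone rendering of the
patched helper, not the exact ComfyUI code: `comfy.model_management.cast_to_device` is
replaced with a plain `.to(...)`, and the rescale/blend tail after the norm (which the
hunk does not show) is reconstructed by assumption.

import torch

def weight_decompose(dora_scale, weight, lora_diff, alpha, strength,
                     intermediate_dtype, function):
    # Patched signature: `function` maps the LoRA diff into the weight's
    # layout *before* the merged weight and its per-channel norm are
    # computed, instead of being applied to the finished weight afterwards.
    dora_scale = dora_scale.to(device=weight.device, dtype=intermediate_dtype)
    lora_diff *= alpha
    weight_calc = weight + function(lora_diff).type(weight.dtype)
    weight_norm = (
        weight_calc.transpose(0, 1)
        .reshape(weight_calc.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
        .transpose(0, 1)
    )
    # Tail reconstructed by assumption (the hunk ends at the norm): rescale
    # channels toward the learned dora_scale magnitude, then blend by strength.
    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    return weight + strength * (weight_calc - weight)

# Identity hook: with function=lambda x: x this reduces to plain DoRA merging.
w = torch.randn(8, 4)
diff = torch.randn(8, 4)
scale = torch.ones(1, 4)
merged = weight_decompose(scale, w, diff, 1.0, 1.0, torch.float32, lambda x: x)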