diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 54af3bab..29873519 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -43,13 +43,23 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) lora_diff *= alpha weight_calc = weight + function(lora_diff).type(weight.dtype) - weight_norm = ( - weight_calc.transpose(0, 1) - .reshape(weight_calc.shape[1], -1) - .norm(dim=1, keepdim=True) - .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) - .transpose(0, 1) - ) + + wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0] + if wd_on_output_axis: + weight_norm = ( + weight.reshape(weight.shape[0], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[0], *[1] * (weight.dim() - 1)) + ) + else: + weight_norm = ( + weight_calc.transpose(0, 1) + .reshape(weight_calc.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) + .transpose(0, 1) + ) + weight_norm = weight_norm + torch.finfo(weight.dtype).eps weight_calc *= (dora_scale / weight_norm).type(weight.dtype) if strength != 1.0: