Correct DoRA implementation

catboxanon 2025-04-14 12:17:15 -04:00
parent a14c2fc356
commit facf574a61

@@ -413,11 +413,10 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
     lora_diff *= alpha
     weight_calc = weight + function(lora_diff).type(weight.dtype)
     weight_norm = (
-        weight_calc.transpose(0, 1)
-        .reshape(weight_calc.shape[1], -1)
+        weight_calc
+        .reshape(weight_calc.shape[0], -1)
         .norm(dim=1, keepdim=True)
-        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
-        .transpose(0, 1)
+        .reshape(weight_calc.shape[0], *[1] * (weight_calc.dim() - 1))
     )
     weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
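
For context, below is a minimal standalone sketch of the corrected decomposition (assumptions: PyTorch, `function` treated as identity, and the `strength`/intermediate-dtype handling of the full function omitted for brevity). It illustrates the key change: the per-output-channel norm is now taken by flattening everything except dim 0, replacing the old transpose/reshape chain that normalized along the input-channel axis.

import torch

def weight_decompose_sketch(dora_scale, weight, lora_diff, alpha, function=lambda x: x):
    # Hypothetical simplified version of weight_decompose; strength and
    # intermediate-dtype handling are omitted.
    lora_diff = lora_diff * alpha
    weight_calc = weight + function(lora_diff).type(weight.dtype)
    # Per-output-channel L2 norm: flatten all dims except dim 0, take the norm
    # of each row, then reshape so it broadcasts against weight_calc.
    weight_norm = (
        weight_calc
        .reshape(weight_calc.shape[0], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight_calc.shape[0], *[1] * (weight_calc.dim() - 1))
    )
    # Rescale each output channel to the learned DoRA magnitude.
    return weight_calc * (dora_scale / weight_norm).type(weight.dtype)

# Works for both 2D linear weights [out, in] and 4D conv weights [out, in, kh, kw]:
w = torch.randn(8, 3, 3, 3)
diff = torch.randn(8, 3, 3, 3) * 0.01
scale = torch.ones(8, 1, 1, 1)
print(weight_decompose_sketch(scale, w, diff, alpha=1.0).shape)  # torch.Size([8, 3, 3, 3])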