diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c397ee51..78982d79 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -9,7 +9,7 @@ import comfy.model_management from comfy.types import UnetWrapperFunction -def apply_weight_decompose(dora_scale, weight): +def weight_decompose_scale(dora_scale, weight): weight_norm = ( weight.transpose(0, 1) .reshape(weight.shape[1], -1) @@ -18,7 +18,7 @@ def apply_weight_decompose(dora_scale, weight): .transpose(0, 1) ) - return weight * (dora_scale / weight_norm).type(weight.dtype) + return (dora_scale / weight_norm).type(weight.dtype) def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -365,7 +365,7 @@ class ModelPatcher: try: weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) if dora_scale is not None: - weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) + weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "lokr": @@ -407,7 +407,7 @@ class ModelPatcher: try: weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) if dora_scale is not None: - weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) + weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "loha": @@ -439,7 +439,7 @@ class ModelPatcher: try: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) if dora_scale is not None: - weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) + weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "glora": @@ -456,7 +456,7 @@ class ModelPatcher: try: weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) if dora_scale is not None: - weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) + weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) else: