Reduce memory usage when applying DORA: #3557

This commit is contained in:
comfyanonymous 2024-05-24 23:36:48 -04:00
parent 58c9838274
commit efa5a711b2

View File

@ -9,7 +9,7 @@ import comfy.model_management
from comfy.types import UnetWrapperFunction
def apply_weight_decompose(dora_scale, weight):
def weight_decompose_scale(dora_scale, weight):
weight_norm = (
weight.transpose(0, 1)
.reshape(weight.shape[1], -1)
@ -18,7 +18,7 @@ def apply_weight_decompose(dora_scale, weight):
.transpose(0, 1)
)
return weight * (dora_scale / weight_norm).type(weight.dtype)
return (dora_scale / weight_norm).type(weight.dtype)
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
@ -365,7 +365,7 @@ class ModelPatcher:
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
@ -407,7 +407,7 @@ class ModelPatcher:
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
@ -439,7 +439,7 @@ class ModelPatcher:
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
@ -456,7 +456,7 @@ class ModelPatcher:
try:
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else: