mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Fix DORA strength.
This is a different version of #3298 with more correct behavior.
This commit is contained in:
parent
5b87369474
commit
ffc4b7c30e
@ -9,16 +9,26 @@ import comfy.model_management
|
|||||||
from comfy.types import UnetWrapperFunction
|
from comfy.types import UnetWrapperFunction
|
||||||
|
|
||||||
|
|
||||||
def weight_decompose_scale(dora_scale, weight):
|
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
|
||||||
|
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
|
||||||
|
lora_diff *= alpha
|
||||||
|
weight_calc = weight + lora_diff.type(weight.dtype)
|
||||||
weight_norm = (
|
weight_norm = (
|
||||||
weight.transpose(0, 1)
|
weight_calc.transpose(0, 1)
|
||||||
.reshape(weight.shape[1], -1)
|
.reshape(weight_calc.shape[1], -1)
|
||||||
.norm(dim=1, keepdim=True)
|
.norm(dim=1, keepdim=True)
|
||||||
.reshape(weight.shape[1], *[1] * (weight.dim() - 1))
|
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||||
.transpose(0, 1)
|
.transpose(0, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
return (dora_scale / weight_norm).type(weight.dtype)
|
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||||
|
if strength != 1.0:
|
||||||
|
weight_calc -= weight
|
||||||
|
weight += strength * (weight_calc)
|
||||||
|
else:
|
||||||
|
weight[:] = weight_calc
|
||||||
|
return weight
|
||||||
|
|
||||||
|
|
||||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||||
to = model_options["transformer_options"].copy()
|
to = model_options["transformer_options"].copy()
|
||||||
@ -328,7 +338,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def calculate_weight(self, patches, weight, key):
|
def calculate_weight(self, patches, weight, key):
|
||||||
for p in patches:
|
for p in patches:
|
||||||
alpha = p[0]
|
strength = p[0]
|
||||||
v = p[1]
|
v = p[1]
|
||||||
strength_model = p[2]
|
strength_model = p[2]
|
||||||
|
|
||||||
@ -346,26 +356,31 @@ class ModelPatcher:
|
|||||||
|
|
||||||
if patch_type == "diff":
|
if patch_type == "diff":
|
||||||
w1 = v[0]
|
w1 = v[0]
|
||||||
if alpha != 0.0:
|
if strength != 0.0:
|
||||||
if w1.shape != weight.shape:
|
if w1.shape != weight.shape:
|
||||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||||
else:
|
else:
|
||||||
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
weight += strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||||
elif patch_type == "lora": #lora/locon
|
elif patch_type == "lora": #lora/locon
|
||||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
||||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
||||||
dora_scale = v[4]
|
dora_scale = v[4]
|
||||||
if v[2] is not None:
|
if v[2] is not None:
|
||||||
alpha *= v[2] / mat2.shape[0]
|
alpha = v[2] / mat2.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
if v[3] is not None:
|
if v[3] is not None:
|
||||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||||
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
|
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
|
||||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||||
try:
|
try:
|
||||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
|
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
|
||||||
|
else:
|
||||||
|
weight += ((strength * alpha) * lora_diff).type(weight.dtype)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "lokr":
|
elif patch_type == "lokr":
|
||||||
@ -402,19 +417,26 @@ class ModelPatcher:
|
|||||||
if len(w2.shape) == 4:
|
if len(w2.shape) == 4:
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
if v[2] is not None and dim is not None:
|
if v[2] is not None and dim is not None:
|
||||||
alpha *= v[2] / dim
|
alpha = v[2] / dim
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
|
||||||
|
else:
|
||||||
|
weight += ((strength * alpha) * lora_diff).type(weight.dtype)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "loha":
|
elif patch_type == "loha":
|
||||||
w1a = v[0]
|
w1a = v[0]
|
||||||
w1b = v[1]
|
w1b = v[1]
|
||||||
if v[2] is not None:
|
if v[2] is not None:
|
||||||
alpha *= v[2] / w1b.shape[0]
|
alpha = v[2] / w1b.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
w2a = v[3]
|
w2a = v[3]
|
||||||
w2b = v[4]
|
w2b = v[4]
|
||||||
dora_scale = v[7]
|
dora_scale = v[7]
|
||||||
@ -437,14 +459,18 @@ class ModelPatcher:
|
|||||||
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
|
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
|
||||||
|
else:
|
||||||
|
weight += ((strength * alpha) * lora_diff).type(weight.dtype)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "glora":
|
elif patch_type == "glora":
|
||||||
if v[4] is not None:
|
if v[4] is not None:
|
||||||
alpha *= v[4] / v[0].shape[0]
|
alpha = v[4] / v[0].shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
dora_scale = v[5]
|
dora_scale = v[5]
|
||||||
|
|
||||||
@ -454,9 +480,11 @@ class ModelPatcher:
|
|||||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
|
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
|
||||||
|
else:
|
||||||
|
weight += ((strength * alpha) * lora_diff).type(weight.dtype)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user