diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 9a7aa9de..54af3bab 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -1,6 +1,10 @@ +from typing import Optional + import torch import torch.nn as nn +import comfy.model_management + class WeightAdapterBase: name: str @@ -8,7 +12,7 @@ class WeightAdapterBase: weights: list[torch.Tensor] @classmethod - def load(cls, x: str, lora: dict[str, torch.Tensor]) -> "WeightAdapterBase" | None: + def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]: raise NotImplementedError def to_train(self) -> "WeightAdapterTrainBase": @@ -33,3 +37,58 @@ class WeightAdapterTrainBase(nn.Module): super().__init__() # [TODO] Collaborate with LoRA training PR #7032 + + +def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): + dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) + lora_diff *= alpha + weight_calc = weight + function(lora_diff).type(weight.dtype) + weight_norm = ( + weight_calc.transpose(0, 1) + .reshape(weight_calc.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) + .transpose(0, 1) + ) + + weight_calc *= (dora_scale / weight_norm).type(weight.dtype) + if strength != 1.0: + weight_calc -= weight + weight += strength * (weight_calc) + else: + weight[:] = weight_calc + return weight + + +def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor: + """ + Pad a tensor to a new shape with zeros. + + Args: + tensor (torch.Tensor): The original tensor to be padded. + new_shape (List[int]): The desired shape of the padded tensor. + + Returns: + torch.Tensor: A new tensor padded with zeros to the specified shape. + + Note: + If the new shape is smaller than the original tensor in any dimension, + the original tensor will be truncated in that dimension. + """ + if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]): + raise ValueError("The new shape must be larger than the original tensor in all dimensions") + + if len(new_shape) != len(tensor.shape): + raise ValueError("The new shape must have the same number of dimensions as the original tensor") + + # Create a new tensor filled with zeros + padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) + + # Create slicing tuples for both tensors + orig_slices = tuple(slice(0, dim) for dim in tensor.shape) + new_slices = tuple(slice(0, dim) for dim in tensor.shape) + + # Copy the original tensor into the new tensor + padded_tensor[new_slices] = tensor[orig_slices] + + return padded_tensor diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index b79bfd82..9e00e9e6 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -1,11 +1,10 @@ import logging -import torch -import comfy.utils -import comfy.model_management -import comfy.model_base -from comfy.lora import weight_decompose, pad_tensor_to_shape +from optparse import Option +from typing import Optional -from .base import WeightAdapterBase +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape class LoRAAdapter(WeightAdapterBase): @@ -23,7 +22,7 @@ class LoRAAdapter(WeightAdapterBase): alpha: float, dora_scale: torch.Tensor, loaded_keys: set[str] = None, - ) -> "LoRAAdapter" | None: + ) -> Optional["LoRAAdapter"]: if loaded_keys is None: loaded_keys = set()