From 7718ada4eddf101d088b69e159011e4108286b5b Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 22 May 2024 02:07:27 -0400 Subject: [PATCH] Add type annotation UnetWrapperFunction (#3531) * Add type annotation UnetWrapperFunction * nit * Add types.py --- comfy/model_patcher.py | 4 +++- comfy/types.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 comfy/types.py diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 35ede5ee..c397ee51 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -6,6 +6,8 @@ import uuid import comfy.utils import comfy.model_management +from comfy.types import UnetWrapperFunction + def apply_weight_decompose(dora_scale, weight): weight_norm = ( @@ -117,7 +119,7 @@ class ModelPatcher: if disable_cfg1_optimization: self.model_options["disable_cfg1_optimization"] = True - def set_model_unet_function_wrapper(self, unet_wrapper_function): + def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction): self.model_options["model_function_wrapper"] = unet_wrapper_function def set_model_denoise_mask_function(self, denoise_mask_function): diff --git a/comfy/types.py b/comfy/types.py new file mode 100644 index 00000000..a8a3d29f --- /dev/null +++ b/comfy/types.py @@ -0,0 +1,32 @@ +import torch +from typing import Callable, Protocol, TypedDict, Optional, List + + +class UnetApplyFunction(Protocol): + """Function signature protocol on comfy.model_base.BaseModel.apply_model""" + + def __call__(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor: + pass + + +class UnetApplyConds(TypedDict): + """Optional conditions for unet apply function.""" + + c_concat: Optional[torch.Tensor] + c_crossattn: Optional[torch.Tensor] + control: Optional[torch.Tensor] + transformer_options: Optional[dict] + + +class UnetParams(TypedDict): + # Tensor of shape [B, C, H, W] + input: torch.Tensor + # Tensor of shape [B] + timestep: torch.Tensor + c: UnetApplyConds + # List of [0, 1], [0], [1], ... + # 0 means unconditional, 1 means conditional + cond_or_uncond: List[int] + + +UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]