mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
0ee322ec5f
* Added hook_patches to ModelPatcher for weights (model) * Initial changes to calc_cond_batch to eventually support hook_patches * Added current_patcher property to BaseModel * Consolidated add_hook_patches_as_diffs into add_hook_patches func, fixed fp8 support for model-as-lora feature * Added call to initialize_timesteps on hooks in process_conds func, and added call prepare current keyframe on hooks in calc_cond_batch * Added default_conds support in calc_cond_batch func * Added initial set of hook-related nodes, added code to register hooks for loras/model-as-loras, small renaming/refactoring * Made CLIP work with hook patches * Added initial hook scheduling nodes, small renaming/refactoring * Fixed MaxSpeed and default conds implementations * Added support for adding weight hooks that aren't registered on the ModelPatcher at sampling time * Made Set Clip Hooks node work with hooks from Create Hook nodes, began work on better Create Hook Model As LoRA node * Initial work on adding 'model_as_lora' lora type to calculate_weight * Continued work on simpler Create Hook Model As LoRA node, started to implement ModelPatcher callbacks, attachments, and additional_models * Fix incorrect ref to create_hook_patches_clone after moving function * Added injections support to ModelPatcher + necessary bookkeeping, added additional_models support in ModelPatcher, conds, and hooks * Added wrappers to ModelPatcher to facilitate standardized function wrapping * Started scaffolding for other hook types, refactored get_hooks_from_cond to organize hooks by type * Fix skip_until_exit logic bug breaking injection after first run of model * Updated clone_has_same_weights function to account for new ModelPatcher properties, improved AutoPatcherEjector usage in partially_load * Added WrapperExecutor for non-classbound functions, added calc_cond_batch wrappers * Refactored callbacks+wrappers to allow storing lists by id * Added forward_timestep_embed_patch type, added helper functions on 
ModelPatcher for emb_patch and forward_timestep_embed_patch, added helper functions for removing callbacks/wrappers/additional_models by key, added custom_should_register prop to hooks * Added get_attachment func on ModelPatcher * Implement basic MemoryCounter system for determing with cached weights due to hooks should be offloaded in hooks_backup * Modified ControlNet/T2IAdapter get_control function to receive transformer_options as additional parameter, made the model_options stored in extra_args in inner_sample be a clone of the original model_options instead of same ref * Added create_model_options_clone func, modified type annotations to use __future__ so that I can use the better type annotations * Refactored WrapperExecutor code to remove need for WrapperClassExecutor (now gone), added sampler.sample wrapper (pending review, will likely keep but will see what hacks this could currently let me get rid of in ACN/ADE) * Added Combine versions of Cond/Cond Pair Set Props nodes, renamed Pair Cond to Cond Pair, fixed default conds never applying hooks (due to hooks key typo) * Renamed Create Hook Model As LoRA nodes to make the test node the main one (more changes pending) * Added uuid to conds in CFGGuider and uuids to transformer_options to allow uniquely identifying conds in batches during sampling * Fixed models not being unloaded properly due to current_patcher reference; the current ComfyUI model cleanup code requires that nothing else has a reference to the ModelPatcher instances * Fixed default conds not respecting hook keyframes, made keyframes not reset cache when strength is unchanged, fixed Cond Set Default Combine throwing error, fixed model-as-lora throwing error during calculate_weight after a recent ComfyUI update, small refactoring/scaffolding changes for hooks * Changed CreateHookModelAsLoraTest to be the new CreateHookModelAsLora, rename old ones as 'direct' and will be removed prior to merge * Added initial support within CLIP Text Encode 
(Prompt) node for scheduling weight hook CLIP strength via clip_start_percent/clip_end_percent on conds, added schedule_clip toggle to Set CLIP Hooks node, small cleanup/fixes * Fix range check in get_hooks_for_clip_schedule so that proper keyframes get assigned to corresponding ranges * Optimized CLIP hook scheduling to treat same strength as same keyframe * Less fragile memory management. * Make encode_from_tokens_scheduled call cleaner, rollback change in model_patcher.py for hook_patches_backup dict * Fix issue. * Remove useless function. * Prevent and detect some types of memory leaks. * Run garbage collector when switching workflow if needed. * Moved WrappersMP/CallbacksMP/WrapperExecutor to patcher_extension.py * Refactored code to store wrappers and callbacks in transformer_options, added apply_model and diffusion_model.forward wrappers * Fix issue. * Refactored hooks in calc_cond_batch to be part of get_area_and_mult tuple, added extra_hooks to ControlBase to allow custom controlnets w/ hooks, small cleanup and renaming * Fixed inconsistency of results when schedule_clip is set to False, small renaming/typo fixing, added initial support for ControlNet extra_hooks to work in tandem with normal cond hooks, initial work on calc_cond_batch merging all subdicts in returned transformer_options * Modified callbacks and wrappers so that unregistered types can be used, allowing custom_nodes to have their own unique callbacks/wrappers if desired * Updated different hook types to reflect actual progress of implementation, initial scaffolding for working WrapperHook functionality * Fixed existing weight hook_patches (pre-registered) not working properly for CLIP * Removed Register/Direct hook nodes since they were present only for testing, removed diff-related weight hook calculation as improved_memory removes unload_model_clones and using sample time registered hooks is less hacky * Added clip scheduling support to all other native ComfyUI text encoding nodes (sdxl, 
flux, hunyuan, sd3) * Made WrapperHook functional, added another wrapper/callback getter, added ON_DETACH callback to ModelPatcher * Made opt_hooks append by default instead of replace, renamed comfy.hooks set functions to be more accurate * Added apply_to_conds to Set CLIP Hooks, modified relevant code to allow text encoding to automatically apply hooks to output conds when apply_to_conds is set to True * Fix cached_hook_patches not respecting target_device/memory_counter results * Fixed issue with setting weights from hooks instead of copying them, added additional memory_counter check when caching hook patches * Remove unnecessary torch.no_grad calls for hook patches * Increased MemoryCounter minimum memory to leave free by *2 until a better way to get inference memory estimate of currently loaded models exists * For encode_from_tokens_scheduled, allow start_percent and end_percent in add_dict to limit which scheduled conds get encoded for optimization purposes * Removed a .to call on results of calculate_weight in patch_hook_weight_to_device that was screwing up the intermediate results for fp8 prior to being passed into stochastic_rounding call * Made encode_from_tokens_scheduled work when no hooks are set on patcher * Small cleanup of comments * Turn off hook patch caching when only 1 hook present in sampling, replace some current_hook = None with calls to self.patch_hooks(None) instead to avoid a potential edge case * On Cond/Cond Pair nodes, removed opt_ prefix from optional inputs * Allow both FLOATS and FLOAT for floats_strength input * Revert change, does not work * Made patch_hook_weight_to_device respect set_func and convert_func * Make discard_model_sampling True by default * Add changes manually from 'master' so merge conflict resolution goes more smoothly * Cleaned up text encode nodes with just a single clip.encode_from_tokens_scheduled call * Make sure encode_from_tokens_scheduled will respect use_clip_schedule on clip * Made nodes in nodes_hooks 
be marked as experimental (beta) * Add get_nested_additional_models for cases where additional_models could have their own additional_models, and add robustness for circular additional_models references * Made finalize_default_conds area math consistent with other sampling code * Changed 'opt_hooks' input of Cond/Cond Pair Set Default Combine nodes to 'hooks' * Remove a couple old TODO's and a no longer necessary workaround
809 lines
35 KiB
Python
809 lines
35 KiB
Python
"""
|
|
This file is part of ComfyUI.
|
|
Copyright (C) 2024 Comfy
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
"""
|
|
|
|
import torch
|
|
import logging
|
|
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
|
from comfy.ldm.cascade.stage_c import StageC
|
|
from comfy.ldm.cascade.stage_b import StageB
|
|
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
|
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
|
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
|
import comfy.ldm.genmo.joint_model.asymm_models_joint
|
|
import comfy.ldm.aura.mmdit
|
|
import comfy.ldm.hydit.models
|
|
import comfy.ldm.audio.dit
|
|
import comfy.ldm.audio.embedders
|
|
import comfy.ldm.flux.model
|
|
import comfy.ldm.lightricks.model
|
|
|
|
import comfy.model_management
|
|
import comfy.patcher_extension
|
|
import comfy.conds
|
|
import comfy.ops
|
|
from enum import Enum
|
|
from . import utils
|
|
import comfy.latent_formats
|
|
import math
|
|
from typing import TYPE_CHECKING
|
|
if TYPE_CHECKING:
|
|
from comfy.model_patcher import ModelPatcher
|
|
|
|
class ModelType(Enum):
    """Model parameterization families understood by ``model_sampling``.

    Each member is mapped by ``model_sampling`` to a (sampling schedule,
    prediction type) pair of classes. Values are explicit integers 1-8.
    """

    # model_sampling: discrete schedule + EPS prediction
    EPS = 1
    # model_sampling: discrete schedule + V_PREDICTION
    V_PREDICTION = 2
    # model_sampling: continuous EDM schedule + V_PREDICTION
    V_PREDICTION_EDM = 3
    # model_sampling: StableCascadeSampling + EPS
    STABLE_CASCADE = 4
    # model_sampling: continuous EDM schedule + EDM prediction
    EDM = 5
    # model_sampling: discrete flow schedule + CONST
    FLOW = 6
    # model_sampling: ModelSamplingContinuousV + V_PREDICTION
    V_PREDICTION_CONTINUOUS = 7
    # model_sampling: ModelSamplingFlux + CONST
    FLUX = 8
|
|
|
|
|
|
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
|
|
|
|
|
|
def model_sampling(model_config, model_type):
    """Build the sampling object for ``model_type``.

    A sampling-schedule class ``s`` and a prediction-type class ``c`` are
    selected from ``model_type``, combined into a fresh ``ModelSampling(s, c)``
    subclass, and instantiated with ``model_config``.

    Args:
        model_config: passed unchanged to the ``ModelSampling`` constructor.
        model_type: a ``ModelType`` member.

    Returns:
        An instance of the dynamically created ``ModelSampling`` class.

    Raises:
        ValueError: if ``model_type`` is not one of the handled members.
    """
    s = ModelSamplingDiscrete

    if model_type == ModelType.EPS:
        c = EPS
    elif model_type == ModelType.V_PREDICTION:
        c = V_PREDICTION
    elif model_type == ModelType.V_PREDICTION_EDM:
        c = V_PREDICTION
        s = ModelSamplingContinuousEDM
    elif model_type == ModelType.FLOW:
        c = comfy.model_sampling.CONST
        s = comfy.model_sampling.ModelSamplingDiscreteFlow
    elif model_type == ModelType.STABLE_CASCADE:
        c = EPS
        s = StableCascadeSampling
    elif model_type == ModelType.EDM:
        c = EDM
        s = ModelSamplingContinuousEDM
    elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
        c = V_PREDICTION
        s = ModelSamplingContinuousV
    elif model_type == ModelType.FLUX:
        c = comfy.model_sampling.CONST
        s = comfy.model_sampling.ModelSamplingFlux
    else:
        # Fail with a clear message instead of an UnboundLocalError on ``c`` below.
        raise ValueError("Unsupported model type: {}".format(model_type))

    class ModelSampling(s, c):
        pass

    return ModelSampling(model_config)
|
|
|
|
|
|
class BaseModel(torch.nn.Module):
    """Common wrapper around a diffusion backbone stored as ``self.diffusion_model``.

    Builds the backbone from ``model_config.unet_config`` (unless
    ``disable_unet_model_creation`` is set), attaches the sampling object from
    ``model_sampling``, turns sampler kwargs into cond objects
    (``extra_conds``) and runs the backbone to produce a denoised prediction
    (``apply_model``).
    """

    def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
        super().__init__()

        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config
        self.manual_cast_dtype = model_config.manual_cast_dtype
        self.device = device
        # Not set here; left as None for external code to assign.
        self.current_patcher: 'ModelPatcher' = None

        if not unet_config.get("disable_unet_model_creation", False):
            if model_config.custom_operations is None:
                # fp8 optimizations default to on when the config declares scaled_fp8.
                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
            else:
                operations = model_config.custom_operations
            self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
            if comfy.model_management.force_channels_last():
                self.diffusion_model.to(memory_format=torch.channels_last)
                logging.debug("using channels last mode for diffusion model")
            logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
        self.model_type = model_type
        self.model_sampling = model_sampling(model_config, model_type)

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0

        self.concat_keys = ()
        logging.info("model_type {}".format(model_type.name))
        logging.debug("adm {}".format(self.adm_channels))
        self.memory_usage_factor = model_config.memory_usage_factor

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options=None, **kwargs):
        """Run ``_apply_model`` through any APPLY_MODEL wrappers in ``transformer_options``.

        ``transformer_options`` defaults to a fresh empty dict per call. It is
        forwarded to the diffusion model, which may mutate it, so a shared
        default dict would leak state between calls.
        """
        if transformer_options is None:
            transformer_options = {}
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._apply_model,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options)
        ).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)

    def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options=None, **kwargs):
        """Compute the denoised output for noisy input ``x`` at sigma ``t``.

        Steps:
        - Scale the input via ``model_sampling.calculate_input``.
        - Append ``c_concat`` along dim 1.
        - Cast inputs to the compute dtype (``manual_cast_dtype`` if set, else the weight dtype).
          Integer (int/long) tensor kwargs keep their dtype.
        - Convert sigma to the model timestep and run the backbone.
        - Map the output back with ``model_sampling.calculate_denoised``.
        """
        if transformer_options is None:
            transformer_options = {}
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        if c_concat is not None:
            xc = torch.cat([xc] + [c_concat], dim=1)

        context = c_crossattn
        dtype = self.get_dtype()

        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        xc = xc.to(dtype)
        t = self.model_sampling.timestep(t).float()
        # Models may run without cross-attention conditioning; don't crash on None.
        if context is not None:
            context = context.to(dtype)
        extra_conds = {}
        for o in kwargs:
            extra = kwargs[o]
            if hasattr(extra, "dtype"):
                if extra.dtype != torch.int and extra.dtype != torch.long:
                    extra = extra.to(dtype)
            extra_conds[o] = extra

        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)

    def get_dtype(self):
        """Return the dtype reported by the diffusion model."""
        return self.diffusion_model.dtype

    def is_adm(self):
        """True when the config declared a positive ``adm_in_channels``."""
        return self.adm_channels > 0

    def encode_adm(self, **kwargs):
        """Return the ADM ("y") vector; the base model has none. Subclasses override."""
        return None

    def concat_cond(self, **kwargs):
        """Build the channel-concat conditioning described by ``self.concat_keys``.

        Returns None when ``concat_keys`` is empty. Otherwise each key
        contributes one tensor, concatenated along dim 1:
        - ``"mask"``: the denoise mask, or all-ones when there is no mask.
        - ``"masked_image"``: the concat latent, or ``blank_inpaint_image_like(noise)``
          when there is no mask.

        Requires ``kwargs["device"]`` and ``kwargs["noise"]``. A latent image is
        taken from ``concat_latent_image``, otherwise from ``latent_image``.
        """
        if len(self.concat_keys) == 0:
            return None

        cond_concat = []
        denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
        concat_latent_image = kwargs.get("concat_latent_image", None)
        if concat_latent_image is None:
            # NOTE(review): used without process_latent_in — presumably already in model space; confirm.
            concat_latent_image = kwargs.get("latent_image", None)
        else:
            concat_latent_image = self.process_latent_in(concat_latent_image)

        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        if concat_latent_image.shape[1:] != noise.shape[1:]:
            concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

        concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])

        if denoise_mask is not None:
            # Keep a single channel when the mask has the same rank as the noise.
            if len(denoise_mask.shape) == len(noise.shape):
                denoise_mask = denoise_mask[:, :1]

            denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
            if denoise_mask.shape[-2:] != noise.shape[-2:]:
                denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
            denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])

        for ck in self.concat_keys:
            if denoise_mask is not None:
                if ck == "mask":
                    cond_concat.append(denoise_mask.to(device))
                elif ck == "masked_image":
                    cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
            else:
                if ck == "mask":
                    cond_concat.append(torch.ones_like(noise)[:, :1])
                elif ck == "masked_image":
                    cond_concat.append(self.blank_inpaint_image_like(noise))
        return torch.cat(cond_concat, dim=1)

    def extra_conds(self, **kwargs):
        """Translate sampler kwargs into the cond objects fed to ``apply_model``.

        Produces ``c_concat`` (from ``concat_cond``, overridden by ``noise_concat``
        when given), ``y`` (from ``encode_adm``), ``c_crossattn`` and
        ``crossattn_controlnet``. Only entries that are not None are included.
        """
        out = {}
        concat_cond = self.concat_cond(**kwargs)
        if concat_cond is not None:
            out['c_concat'] = comfy.conds.CONDNoiseShape(concat_cond)

        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)

        cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
        if cross_attn_cnet is not None:
            out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet)

        c_concat = kwargs.get("noise_concat", None)
        if c_concat is not None:
            out['c_concat'] = comfy.conds.CONDNoiseShape(c_concat)

        return out

    def load_model_weights(self, sd, unet_prefix=""):
        """Load backbone weights from ``sd`` (non-strict) and return ``self``.

        Keys starting with ``unet_prefix`` are popped from ``sd`` (``sd`` is
        mutated), stripped of the prefix, and passed through
        ``model_config.process_unet_state_dict``. Missing and unexpected keys
        are logged as warnings.
        """
        to_load = {}
        for k in list(sd.keys()):
            if k.startswith(unet_prefix):
                to_load[k[len(unet_prefix):]] = sd.pop(k)

        to_load = self.model_config.process_unet_state_dict(to_load)
        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            logging.warning("unet missing: {}".format(m))

        if len(u) > 0:
            logging.warning("unet unexpected: {}".format(u))
        del to_load
        return self

    def process_latent_in(self, latent):
        """Apply ``latent_format.process_in`` to ``latent``."""
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent):
        """Apply ``latent_format.process_out`` to ``latent``."""
        return self.latent_format.process_out(latent)

    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
        """Return a single state dict with backbone weights plus optional CLIP/VAE/CLIP-vision weights.

        Marker keys are added as needed:
        - ``scaled_fp8``: an empty tensor whose dtype records ``model_config.scaled_fp8``.
        - ``v_pred``: an empty tensor, added for ``ModelType.V_PREDICTION``.
        """
        extra_sds = []
        if clip_state_dict is not None:
            extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
        if vae_state_dict is not None:
            extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
        if clip_vision_state_dict is not None:
            extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

        unet_state_dict = self.diffusion_model.state_dict()

        if self.model_config.scaled_fp8 is not None:
            unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)

        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

        if self.model_type == ModelType.V_PREDICTION:
            unet_state_dict["v_pred"] = torch.tensor([])

        for sd in extra_sds:
            unet_state_dict.update(sd)

        return unet_state_dict

    def set_inpaint(self):
        """Switch to inpainting conditioning: concat a mask and a masked image."""
        self.concat_keys = ("mask", "masked_image")

        def blank_inpaint_image_like(latent_image):
            blank_image = torch.ones_like(latent_image)
            # these are the values for "zero" in pixel space translated to latent space
            blank_image[:, 0] *= 0.8223
            blank_image[:, 1] *= -0.6876
            blank_image[:, 2] *= 0.6364
            blank_image[:, 3] *= 0.1380
            return blank_image

        self.blank_inpaint_image_like = blank_inpaint_image_like

    def memory_required(self, input_shape):
        """Heuristic memory estimate in bytes for a forward pass on ``input_shape``.

        ``area`` is batch times the spatial dims (everything after dim 1).
        """
        area = input_shape[0] * math.prod(input_shape[2:])
        if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
            dtype = self.get_dtype()
            if self.manual_cast_dtype is not None:
                dtype = self.manual_cast_dtype
            #TODO: this needs to be tweaked
            return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
        #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
        return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
|
|
|
|
|
|
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
    """Build the unCLIP ADM vector from one or more CLIP vision conditionings.

    Per embedding:
    - Each row of ``unclip_cond["clip_vision_output"].image_embeds`` is
      noise-augmented at level ``round((max_noise_level - 1) * noise_augmentation)``.
    - It is concatenated (dim 1) with its noise-level embedding.
    - The result is scaled by ``strength``.

    With exactly one embedding that result is returned. With several:
    - The scaled vectors are summed.
    - The first ``noise_augmentor.time_embed.dim`` columns are re-augmented at
      ``noise_augment_merge``.
    - The result is re-joined with the new noise-level embedding.

    Args:
        unclip_conditioning: iterable of dicts with keys "clip_vision_output",
            "strength" and "noise_augmentation".
        device: device the embeddings and noise-level tensors are moved to.
        noise_augmentor: callable ``(x, noise_level=..., seed=...) -> (x_aug, level_emb)``
            exposing ``max_noise_level`` and ``time_embed.dim``.
        noise_augment_merge: augmentation strength used after merging several embeddings.
        seed: forwarded to the per-embedding augmentation calls only (not the merge call).

    Raises:
        ValueError: if ``unclip_conditioning`` provides no image embeddings.
    """
    adm_inputs = []
    for unclip_cond in unclip_conditioning:
        for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
            weight = unclip_cond["strength"]
            noise_augment = unclip_cond["noise_augmentation"]
            noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
            c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device), seed=seed)
            adm_inputs.append(torch.cat((c_adm, noise_level_emb), 1) * weight)

    if not adm_inputs:
        # Previously this fell through to an UnboundLocalError on the return value.
        raise ValueError("unclip_adm: no CLIP vision image embeddings were provided")

    if len(adm_inputs) == 1:
        return adm_inputs[0]

    adm_out = torch.stack(adm_inputs).sum(0)
    noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment_merge)
    c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
    return torch.cat((c_adm, noise_level_emb), 1)
|
|
|
|
class SD21UNCLIP(BaseModel):
    """SD 2.1 unCLIP model: the ADM ("y") input comes from CLIP vision embeddings."""

    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

    def encode_adm(self, **kwargs):
        """Return ``unclip_adm(...)`` output, or zeros of width ``adm_channels`` when no unCLIP cond is given."""
        conditioning = kwargs.get("unclip_conditioning", None)
        target_device = kwargs["device"]
        if conditioning is not None:
            merge_strength = kwargs.get("unclip_noise_augment_merge", 0.05)
            aug_seed = kwargs.get("seed", 0) - 10
            return unclip_adm(conditioning, target_device, self.noise_augmentor, merge_strength, aug_seed)
        return torch.zeros((1, self.adm_channels))
|
|
|
|
def sdxl_pooled(args, noise_augmentor):
    """Return the pooled vector that starts the SDXL ADM input.

    When ``args`` carries a non-None ``unclip_conditioning``, the vector comes
    from :func:`unclip_adm` (seeded with ``seed - 10``) and is truncated to its
    first 1280 columns. Otherwise ``args["pooled_output"]`` is returned
    unchanged.

    A present-but-None ``unclip_conditioning`` entry is treated as absent
    rather than being passed to ``unclip_adm``, where it would fail.
    """
    unclip_conditioning = args.get("unclip_conditioning", None)
    if unclip_conditioning is not None:
        return unclip_adm(unclip_conditioning, args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:, :1280]
    return args["pooled_output"]
|
|
|
|
class SDXLRefiner(BaseModel):
    """SDXL refiner.

    The ADM vector is the pooled CLIP output followed by ``Timestep(256)``
    embeddings of height, width, crop_h, crop_w and aesthetic score.
    """

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(
            noise_schedule_config={"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"},
            timestep_dim=1280,
        )

    def encode_adm(self, **kwargs):
        """Concatenate pooled CLIP with size/crop/aesthetic embeddings, repeated per pooled row."""
        pooled = sdxl_pooled(kwargs, self.noise_augmentor)

        # Negative prompts default to a lower aesthetic score.
        score_default = 2.5 if kwargs.get("prompt_type", "") == "negative" else 6
        scalars = (
            kwargs.get("height", 768),
            kwargs.get("width", 768),
            kwargs.get("crop_h", 0),
            kwargs.get("crop_w", 0),
            kwargs.get("aesthetic_score", score_default),
        )

        embedded = [self.embedder(torch.Tensor([value])) for value in scalars]
        flat = torch.flatten(torch.cat(embedded)).unsqueeze(dim=0)
        flat = flat.repeat(pooled.shape[0], 1)
        return torch.cat((pooled.to(flat.device), flat), dim=1)
|
|
|
|
class SDXL(BaseModel):
    """SDXL base model.

    The ADM vector is the pooled CLIP output followed by ``Timestep(256)``
    embeddings of height, width, crop_h, crop_w, target_height and
    target_width.
    """

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(
            noise_schedule_config={"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"},
            timestep_dim=1280,
        )

    def encode_adm(self, **kwargs):
        """Concatenate pooled CLIP with size/crop/target-size embeddings, repeated per pooled row."""
        pooled = sdxl_pooled(kwargs, self.noise_augmentor)

        w = kwargs.get("width", 768)
        h = kwargs.get("height", 768)
        scalars = (
            h,
            w,
            kwargs.get("crop_h", 0),
            kwargs.get("crop_w", 0),
            # Target size falls back to the actual size when unspecified.
            kwargs.get("target_height", h),
            kwargs.get("target_width", w),
        )

        embedded = [self.embedder(torch.Tensor([value])) for value in scalars]
        flat = torch.flatten(torch.cat(embedded)).unsqueeze(dim=0)
        flat = flat.repeat(pooled.shape[0], 1)
        return torch.cat((pooled.to(flat.device), flat), dim=1)
|
|
|
|
|
|
class SVD_img2vid(BaseModel):
    """Stable Video Diffusion image-to-video model.

    ADM ("y") embeds fps-1, motion bucket id and augmentation level. The
    conditioning latent is channel-concatenated, and the frame count equals
    the noise batch size.
    """

    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)

    def encode_adm(self, **kwargs):
        """Return a (1, N) row of Timestep embeddings for fps-1, motion bucket id and augmentation."""
        scalars = (
            kwargs.get("fps", 6) - 1,
            kwargs.get("motion_bucket_id", 127),
            kwargs.get("augmentation_level", 0),
        )
        pieces = [self.embedder(torch.Tensor([value])) for value in scalars]
        return torch.flatten(torch.cat(pieces)).unsqueeze(dim=0)

    def extra_conds(self, **kwargs):
        """Build y, c_concat, optional c_crossattn / time_context, and num_video_frames."""
        conds = {}
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            conds['y'] = comfy.conds.CONDRegular(adm)

        cond_latent = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]  # NOTE(review): looked up but not used below

        if cond_latent is None:
            cond_latent = torch.zeros_like(noise)
        if cond_latent.shape[1:] != noise.shape[1:]:
            cond_latent = utils.common_upscale(cond_latent, noise.shape[-1], noise.shape[-2], "bilinear", "center")
        cond_latent = utils.resize_to_batch_size(cond_latent, noise.shape[0])
        conds['c_concat'] = comfy.conds.CONDNoiseShape(cond_latent)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            conds['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)

        if "time_conditioning" in kwargs:
            conds["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])

        conds['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
        return conds
|
|
|
|
class SV3D_u(SVD_img2vid):
    """SV3D (unconditioned orbit) variant: the ADM vector embeds only the augmentation level."""

    def encode_adm(self, **kwargs):
        """Return a (1, N) row holding the Timestep embedding of the augmentation level."""
        level = torch.flatten(torch.Tensor([kwargs.get("augmentation_level", 0)]))
        embedded = self.embedder(level)
        return torch.flatten(torch.cat([embedded])).unsqueeze(dim=0)
|
|
|
|
class SV3D_p(SVD_img2vid):
    """SV3D (posed) variant: the ADM vector adds camera elevation/azimuth embeddings."""

    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder_512 = Timestep(512)

    def encode_adm(self, **kwargs):
        """Concatenate (dim 1) augmentation, polar (90 - elevation) and azimuth embeddings.

        Each part is resized to the noise batch. Angles are taken in degrees,
        wrapped with fmod 360 and converted to radians.
        """
        augmentation = kwargs.get("augmentation_level", 0)
        elevation = kwargs.get("elevation", 0)  # degrees
        azimuth = kwargs.get("azimuth", 0)  # degrees
        noise = kwargs.get("noise", None)

        def embed_angle(degrees):
            return self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(degrees), 360.0)))

        parts = [
            self.embedder(torch.flatten(torch.Tensor([augmentation]))),
            embed_angle(90 - torch.Tensor([elevation])),
            embed_angle(torch.Tensor([azimuth])),
        ]
        parts = [utils.resize_to_batch_size(part, noise.shape[0]) for part in parts]
        return torch.cat(parts, dim=1)
|
|
|
|
|
|
class Stable_Zero123(BaseModel):
    """Stable Zero123 novel-view model.

    Adds a linear ``cc_projection`` whose weights are supplied at
    construction. It is applied to cross-attention inputs whose last dim is
    not 768.
    """

    def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
        super().__init__(model_config, model_type, device=device)
        self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
        # The freshly created Parameters default to requires_grad=True, and an
        # in-place copy_ into such a leaf raises unless autograd is disabled.
        with torch.no_grad():
            self.cc_projection.weight.copy_(cc_projection_weight)
            self.cc_projection.bias.copy_(cc_projection_bias)

    def extra_conds(self, **kwargs):
        """Build c_concat (the conditioning latent, zeros if absent) and the projected c_crossattn."""
        out = {}

        latent_image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)

        if latent_image is None:
            latent_image = torch.zeros_like(noise)

        if latent_image.shape[1:] != noise.shape[1:]:
            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

        latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])

        out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            # Project inputs whose last dim is not 768 through cc_projection.
            if cross_attn.shape[-1] != 768:
                cross_attn = self.cc_projection(cross_attn)
            out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
        return out
|
|
|
|
class SD_X4Upscaler(BaseModel):
    """SD x4 upscaler: a low-res image is channel-concatenated, and its noise level is passed as y."""

    def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = ImageConcatWithNoiseAugmentation(
            noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02},
            max_noise_level=350,
        )

    def extra_conds(self, **kwargs):
        """Return c_concat (the possibly noise-augmented concat image) and y (its noise level)."""
        low_res = kwargs.get("concat_image", None)
        noise = kwargs.get("noise", None)
        augment = kwargs.get("noise_augmentation", 0.0)
        device = kwargs["device"]
        aug_seed = kwargs["seed"] - 10

        level = round(self.noise_augmentor.max_noise_level * augment)

        if low_res is None:
            # Placeholder: first three channels of a zero tensor shaped like the noise.
            low_res = torch.zeros_like(noise)[:, :3]

        if low_res.shape[1:] != noise.shape[1:]:
            low_res = utils.common_upscale(low_res.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")

        level = torch.tensor([level], device=device)
        if augment > 0:
            low_res, level = self.noise_augmentor(low_res.to(device), noise_level=level, seed=aug_seed)

        low_res = utils.resize_to_batch_size(low_res, noise.shape[0])

        return {
            'c_concat': comfy.conds.CONDNoiseShape(low_res),
            'y': comfy.conds.CONDRegular(level),
        }
|
|
|
|
class IP2P:
    """Mixin providing ``concat_cond`` for InstructPix2Pix-style models.

    The concrete class must set ``self.process_ip2p_image_in``.
    """

    def concat_cond(self, **kwargs):
        """Return the edit-source latent sized to the noise, then passed through ``process_ip2p_image_in``.

        Zeros are used when no ``concat_latent_image`` is supplied.
        """
        source = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        if source is None:
            source = torch.zeros_like(noise)

        if source.shape[1:] != noise.shape[1:]:
            source = utils.common_upscale(source.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")

        batched = utils.resize_to_batch_size(source, noise.shape[0])
        return self.process_ip2p_image_in(batched)
|
|
|
|
|
|
class SD15_instructpix2pix(IP2P, BaseModel):
    """SD1.5 InstructPix2Pix: the concat latent is used without extra preprocessing."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)

        def passthrough(image):
            return image

        self.process_ip2p_image_in = passthrough
|
|
|
|
|
|
class SDXL_instructpix2pix(IP2P, SDXL):
    """SDXL InstructPix2Pix.

    The preprocessing of the concat latent depends on ``model_type``:
    - ``V_PREDICTION_EDM`` (cosxl) applies ``SDXL().process_in``.
    - Any other type (diffusers) passes the latent through unchanged.
    """

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)

        if model_type != ModelType.V_PREDICTION_EDM:
            def preprocess(image):
                return image  # diffusers ip2p
        else:
            def preprocess(image):
                return comfy.latent_formats.SDXL().process_in(image)  # cosxl ip2p

        self.process_ip2p_image_in = preprocess
|
|
|
|
|
|
class StableCascade_C(BaseModel):
    """Stable Cascade stage C (prior), built on the StageC backbone."""

    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=StageC)
        self.diffusion_model.eval().requires_grad_(False)

    def extra_conds(self, **kwargs):
        """Build the stage C conds.

        - clip_text_pooled: from ``kwargs["pooled_output"]``; this key is required.
        - clip_img: strength-weighted CLIP vision embeds, or zeros of shape (1, 1, 768).
        - sca and crp: zero scalars.
        - clip_text: from ``cross_attn``, when present.
        """
        conds = {}
        pooled = kwargs["pooled_output"]
        if pooled is not None:
            conds['clip_text_pooled'] = comfy.conds.CONDRegular(pooled)

        if "unclip_conditioning" not in kwargs:
            image_embeds = torch.zeros((1, 1, 768))
        else:
            weighted = []
            for entry in kwargs["unclip_conditioning"]:
                strength = entry["strength"]
                weighted.append(entry["clip_vision_output"].image_embeds.unsqueeze(0) * strength)
            image_embeds = torch.cat(weighted, dim=1)
        conds["clip_img"] = comfy.conds.CONDRegular(image_embeds)
        conds["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
        conds["crp"] = comfy.conds.CONDRegular(torch.zeros((1,)))

        text = kwargs.get("cross_attn", None)
        if text is not None:
            conds['clip_text'] = comfy.conds.CONDCrossAttn(text)
        return conds
|
|
|
|
|
|
class StableCascade_B(BaseModel):
    """Stable Cascade stage B model wrapper (diffusion model: ``StageB``)."""

    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=StageB)
        self.diffusion_model.eval().requires_grad_(False)

    def extra_conds(self, **kwargs):
        """Build the stage-B conditioning dict.

        Entries:
          - ``clip``: the pooled output, when present and not None.
          - ``effnet``: the ``stable_cascade_prior`` kwarg, or a zero
            placeholder derived from ``noise`` when absent/None.
          - ``sca``: a zero placeholder of shape (1,).
        """
        out = {}
        noise = kwargs.get("noise", None)

        # .get(): a missing pooled output previously raised KeyError even
        # though the None case is handled.
        clip_text_pooled = kwargs.get("pooled_output", None)
        if clip_text_pooled is not None:
            out['clip'] = comfy.conds.CONDRegular(clip_text_pooled)

        prior = kwargs.get("stable_cascade_prior", None)
        if prior is None:
            #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
            # Built lazily: the old kwargs.get default allocated this tensor
            # on every call even when a prior was supplied.
            prior = torch.zeros(
                (1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42),
                dtype=noise.dtype, layout=noise.layout, device=noise.device,
            )
        out["effnet"] = comfy.conds.CONDRegular(prior)
        out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
        return out
|
|
|
|
|
|
class SD3(BaseModel):
    """SD3 model using ``OpenAISignatureMMDITWrapper``; defaults to ``ModelType.FLOW``."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=OpenAISignatureMMDITWrapper)

    def encode_adm(self, **kwargs):
        """Return ``kwargs["pooled_output"]`` unchanged."""
        pooled = kwargs["pooled_output"]
        return pooled

    def extra_conds(self, **kwargs):
        """Extend the base conditioning with ``c_crossattn`` when ``cross_attn`` is given."""
        conds = super().extra_conds(**kwargs)
        text_context = kwargs.get("cross_attn", None)
        if text_context is None:
            return conds
        conds['c_crossattn'] = comfy.conds.CONDRegular(text_context)
        return conds
|
|
|
|
|
|
class AuraFlow(BaseModel):
    """AuraFlow model using ``comfy.ldm.aura.mmdit.MMDiT``; defaults to ``ModelType.FLOW``."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.aura.mmdit.MMDiT)

    def extra_conds(self, **kwargs):
        """Extend the base conditioning with ``c_crossattn`` when ``cross_attn`` is given."""
        conds = super().extra_conds(**kwargs)
        text_context = kwargs.get("cross_attn", None)
        if text_context is None:
            return conds
        conds['c_crossattn'] = comfy.conds.CONDRegular(text_context)
        return conds
|
|
|
|
|
|
class StableAudio1(BaseModel):
    """Stable Audio model conditioned on start and total duration in seconds.

    Two ``NumberConditioner`` embedders (dim 768, range 0-512) turn the
    ``seconds_start`` / ``seconds_total`` values into embeddings. They are
    saved under ``conditioner.conditioners.*`` keys.
    """

    def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
        conditioner_cls = comfy.ldm.audio.embedders.NumberConditioner
        self.seconds_start_embedder = conditioner_cls(768, min_val=0, max_val=512)
        self.seconds_total_embedder = conditioner_cls(768, min_val=0, max_val=512)
        pairs = (
            (self.seconds_start_embedder, seconds_start_embedder_weights),
            (self.seconds_total_embedder, seconds_total_embedder_weights),
        )
        for embedder, weights in pairs:
            embedder.load_state_dict(weights)

    def extra_conds(self, **kwargs):
        """Build ``global_embed`` and, when ``cross_attn`` is given, ``c_crossattn``.

        ``global_embed`` is the start/total embeddings joined on the last dim
        and flattened to one row. ``c_crossattn`` appends both embeddings,
        repeated per batch item, to ``cross_attn`` along dim 1.
        """
        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        # NOTE(review): 21.53 presumably converts latent length to seconds — confirm.
        default_total = int(noise.shape[-1] / 21.53)
        seconds_start = kwargs.get("seconds_start", 0)
        seconds_total = kwargs.get("seconds_total", default_total)

        start_embed = self.seconds_start_embedder([seconds_start])[0].to(device)
        total_embed = self.seconds_total_embedder([seconds_total])[0].to(device)

        global_embed = torch.cat([start_embed, total_embed], dim=-1).reshape((1, -1))
        out = {'global_embed': comfy.conds.CONDRegular(global_embed)}

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            batch = cross_attn.shape[0]
            pieces = [
                cross_attn.to(device),
                start_embed.repeat((batch, 1, 1)),
                total_embed.repeat((batch, 1, 1)),
            ]
            out['c_crossattn'] = comfy.conds.CONDRegular(torch.cat(pieces, dim=1))
        return out

    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
        """Return the base state dict plus both embedders' weights under their prefixes."""
        sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
        prefixed = {
            "conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(),
            "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict(),
        }
        for prefix, embedder_sd in prefixed.items():
            for key, value in embedder_sd.items():
                sd["{}{}".format(prefix, key)] = value
        return sd
|
|
|
|
class HunyuanDiT(BaseModel):
    """HunyuanDiT model wrapper; defaults to ``ModelType.V_PREDICTION``."""

    def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)

    def extra_conds(self, **kwargs):
        """Extend the base conditioning with text inputs and ``image_meta_size``.

        Each optional kwarg below is forwarded unchanged, under the given
        output key, when it is present and not None. ``image_meta_size`` is
        always added as
        ``[[height, width, target_height, target_width, 0, 0]]``.
        ``width``/``height`` default to 768 and the targets default to them.
        """
        out = super().extra_conds(**kwargs)

        for src_key, dst_key in (
            ("cross_attn", "c_crossattn"),
            ("attention_mask", "text_embedding_mask"),
            ("conditioning_mt5xl", "encoder_hidden_states_t5"),
            ("attention_mask_mt5xl", "text_embedding_mask_t5"),
        ):
            value = kwargs.get(src_key, None)
            if value is not None:
                out[dst_key] = comfy.conds.CONDRegular(value)

        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        target_width = kwargs.get("target_width", width)
        target_height = kwargs.get("target_height", height)

        # The last two entries are fixed at 0; crop_w/crop_h kwargs are not
        # forwarded. NOTE(review): these slots presumably hold crop offsets —
        # confirm whether crop_h/crop_w should be passed through here.
        out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
        return out
|
|
|
|
class Flux(BaseModel):
    """Flux model wrapper, including channel-concat control-LoRA and inpaint variants."""

    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)

    def concat_cond(self, **kwargs):
        """Return the extra channels to concatenate with the latent, or None.

        The model's input channel count (from ``img_in`` weights, or the config
        as a fallback) is compared with ``out_channels``:
          - ``<= out_channels``: no concat input; returns None.
          - ``<= 2 * out_channels``: returns only the processed concat image.
          - otherwise (inpaint model): returns the image with an 8x8
            space-to-depth packed mask appended on dim 1.
        """
        try:
            #Handle Flux control loras dynamically changing the img_in weight.
            num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
        except Exception:
            #Some cases like tensorrt might not have the weights accessible
            # (Exception, not a bare except, so KeyboardInterrupt/SystemExit propagate.)
            num_channels = self.model_config.unet_config["in_channels"]

        out_channels = self.model_config.unet_config["out_channels"]
        if num_channels <= out_channels:
            return None

        image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        if image is None:
            image = torch.zeros_like(noise)

        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
        image = utils.resize_to_batch_size(image, noise.shape[0])
        image = self.process_latent_in(image)
        if num_channels <= out_channels * 2:
            return image

        #inpaint model
        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
        if mask is None:
            mask = torch.ones_like(noise)[:, :1]

        # Collapse to a single channel, then upscale to 8x the latent resolution.
        mask = torch.mean(mask, dim=1, keepdim=True)
        mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
        # Space-to-depth: fold each 8x8 patch into 64 channels at latent resolution.
        mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
        mask = utils.resize_to_batch_size(mask, noise.shape[0])
        return torch.cat((image, mask), dim=1)

    def encode_adm(self, **kwargs):
        """Return ``kwargs["pooled_output"]`` unchanged."""
        return kwargs["pooled_output"]

    def extra_conds(self, **kwargs):
        """Extend the base conditioning with ``c_crossattn`` and ``guidance`` (default 3.5)."""
        out = super().extra_conds(**kwargs)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
        out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
        return out
|
|
|
|
class GenmoMochi(BaseModel):
    """Genmo Mochi model using ``AsymmDiTJoint``; defaults to ``ModelType.FLOW``."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)

    def extra_conds(self, **kwargs):
        """Add ``attention_mask``/``num_tokens`` and ``c_crossattn`` when provided."""
        conds = super().extra_conds(**kwargs)

        mask = kwargs.get("attention_mask", None)
        if mask is not None:
            conds['attention_mask'] = comfy.conds.CONDRegular(mask)
            # Token count is the sum of the mask, floored at 1.
            token_count = torch.sum(mask).item()
            conds['num_tokens'] = comfy.conds.CONDConstant(max(1, token_count))

        text_context = kwargs.get("cross_attn", None)
        if text_context is not None:
            conds['c_crossattn'] = comfy.conds.CONDRegular(text_context)
        return conds
|
|
|
|
class LTXV(BaseModel):
    """LTX-Video model using ``LTXVModel``; defaults to ``ModelType.FLUX``."""

    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO

    def extra_conds(self, **kwargs):
        """Forward optional mask/text/guiding-latent inputs and always add ``frame_rate``.

        ``frame_rate`` defaults to 25 when not supplied.
        """
        conds = super().extra_conds(**kwargs)

        for src_key, dst_key in (
            ("attention_mask", "attention_mask"),
            ("cross_attn", "c_crossattn"),
            ("guiding_latent", "guiding_latent"),
        ):
            value = kwargs.get(src_key, None)
            if value is not None:
                conds[dst_key] = comfy.conds.CONDRegular(value)

        conds['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
        return conds
|