Compare commits

...

24 Commits

Author SHA1 Message Date
Jedrzej Kosinski
770675c70a
Merge 7333281698 into 2ff3104f70 2025-01-10 09:19:52 -05:00
comfyanonymous
2ff3104f70 WIP support for Nvidia Cosmos 7B and 14B text to world (video) models. 2025-01-10 09:14:16 -05:00
comfyanonymous
129d8908f7 Add argument to skip the output reshaping in the attention functions. 2025-01-10 06:27:37 -05:00
Jedrzej Kosinski
7333281698 Clean up a typehint 2025-01-07 02:58:59 -06:00
Jedrzej Kosinski
3cd4c5cb0a Rename AddModelsHooks to AdditionalModelsHook, rename SetInjectionsHook to InjectionsHook (not yet implemented, but at least getting the naming figured out) 2025-01-07 02:22:49 -06:00
Jedrzej Kosinski
11c6d56037 Merge branch 'master' into hooks_part2 2025-01-07 01:01:53 -06:00
Jedrzej Kosinski
216fea15ee Made TransformerOptionsHook contribute to registered hooks properly, added some doc strings and removed a so-far unused variable 2025-01-07 00:59:18 -06:00
Jedrzej Kosinski
58bf8815c8 Add a get_injections function to ModelPatcher 2025-01-06 20:34:30 -06:00
Jedrzej Kosinski
1b38f5bf57 removed 4 whitespace lines to satisfy Ruff, 2025-01-06 17:11:12 -06:00
Jedrzej Kosinski
2724ac4a60 Merge branch 'master' into hooks_part2 2025-01-06 17:04:24 -06:00
Jedrzej Kosinski
f48f90e471 Make hook_scope functional for TransformerOptionsHook 2025-01-06 02:23:04 -06:00
Jedrzej Kosinski
6463c39ce0 Merge branch 'master' into hooks_part2 2025-01-06 01:28:26 -06:00
Jedrzej Kosinski
0a7e2ae787 Filter only registered hooks on self.conds in CFGGuider.sample 2025-01-06 01:04:29 -06:00
Jedrzej Kosinski
03a97b604a Fix performance of hooks when hooks are appended via Cond Pair Set Props nodes by properly caching between positive and negative conds, make hook_patches_backup behave as intended (in the case that something pre-registers WeightHooks on the ModelPatcher instead of registering it at sample time) 2025-01-06 01:03:59 -06:00
Jedrzej Kosinski
4446c86052 Made hook clone code sane, made clear ObjectPatchHook and SetInjectionsHook are not yet operational 2025-01-05 22:25:51 -06:00
Jedrzej Kosinski
8270ff312f Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time 2025-01-05 21:07:02 -06:00
Jedrzej Kosinski
db2d7ad9ba Merge branch 'add_sample_sigmas' into hooks_part2 2025-01-05 15:45:13 -06:00
Jedrzej Kosinski
6620d86318 In inner_sample, change "sigmas" to "sampler_sigmas" in transformer_options to not conflict with the "sigmas" that will overwrite "sigmas" in _calc_cond_batch 2025-01-05 15:26:22 -06:00
Jedrzej Kosinski
111fd0cadf Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type 2025-01-04 02:04:07 -06:00
Jedrzej Kosinski
776aa734e1 Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options) 2025-01-04 01:02:21 -06:00
Jedrzej Kosinski
5a2ad032cb Cleaned up hooks.py, refactored Hook.should_register and add_hook_patches to use target_dict instead of target so that more information can be provided about the current execution environment if needed 2025-01-03 20:02:27 -06:00
Jedrzej Kosinski
d44295ef71 Merge branch 'master' into hooks_part2 2025-01-03 18:28:31 -06:00
Jedrzej Kosinski
bf21be066f Merge branch 'master' into hooks_part2 2024-12-30 14:16:22 -06:00
Jedrzej Kosinski
72bbf49349 Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas' 2024-12-29 15:49:09 -06:00
22 changed files with 3821 additions and 227 deletions

View File

@ -16,91 +16,132 @@ import comfy.model_management
import comfy.patcher_extension
from node_helpers import conditioning_set_values
# #######################################################################################################
# Hooks explanation
# -------------------
# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
# make explicit special cases like it does for ControlNet and GLIGEN.
#
# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
# that should run special code when a 'marked' cond is used in sampling.
# #######################################################################################################
class EnumHookMode(enum.Enum):
'''
Priority of hook memory optimization vs. speed, mostly related to WeightHooks.
MinVram: No caching will occur for any operations related to hooks.
MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
'''
MinVram = "minvram"
MaxSpeed = "maxspeed"
class EnumHookType(enum.Enum):
'''
Hook types, each of which has different expected behavior.
'''
Weight = "weight"
Patch = "patch"
ObjectPatch = "object_patch"
AddModels = "add_models"
Callbacks = "callbacks"
Wrappers = "wrappers"
SetInjections = "add_injections"
AdditionalModels = "add_models"
TransformerOptions = "transformer_options"
Injections = "add_injections"
class EnumWeightTarget(enum.Enum):
Model = "model"
Clip = "clip"
class EnumHookScope(enum.Enum):
'''
Determines if hook should be limited in its influence over sampling.
AllConditioning: hook will affect all conds used in sampling.
HookedOnly: hook will only affect the conds it was attached to.
'''
AllConditioning = "all_conditioning"
HookedOnly = "hooked_only"
class _HookRef:
pass
# NOTE: this is an example of how the should_register function should look
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
'''Example for how custom_should_register function can look like.'''
return True
def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
'''Creates base dictionary for use with Hooks' target param.'''
d = {}
if target is not None:
d['target'] = target
d.update(kwargs)
return d
class Hook:
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
hook_keyframe: 'HookKeyframeGroup'=None):
hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
self.hook_type = hook_type
'''Enum identifying the general class of this hook.'''
self.hook_ref = hook_ref if hook_ref else _HookRef()
'''Reference shared between hook clones that have the same value. Should NOT be modified.'''
self.hook_id = hook_id
'''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.'''
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
'''Keyframe storage that can be referenced to get strength for current sampling step.'''
self.hook_scope = hook_scope
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
self.custom_should_register = default_should_register
self.auto_apply_to_nonpositive = False
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
@property
def strength(self):
return self.hook_keyframe.strength
def initialize_timesteps(self, model: 'BaseModel'):
def initialize_timesteps(self, model: BaseModel):
self.reset()
self.hook_keyframe.initialize_timesteps(model)
def reset(self):
self.hook_keyframe.reset()
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: Hook = subtype()
def clone(self):
c: Hook = self.__class__()
c.hook_type = self.hook_type
c.hook_ref = self.hook_ref
c.hook_id = self.hook_id
c.hook_keyframe = self.hook_keyframe
c.hook_scope = self.hook_scope
c.custom_should_register = self.custom_should_register
# TODO: make this do something
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
return c
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
return self.custom_should_register(self, model, model_options, target, registered)
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
return self.custom_should_register(self, model, model_options, target_dict, registered)
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def __eq__(self, other: 'Hook'):
def __eq__(self, other: Hook):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
def __hash__(self):
return hash(self.hook_ref)
class WeightHook(Hook):
'''
Hook responsible for tracking weights to be applied to some model/clip.
Note, value of hook_scope is ignored and is treated as HookedOnly.
'''
def __init__(self, strength_model=1.0, strength_clip=1.0):
super().__init__(hook_type=EnumHookType.Weight)
super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
self.weights: dict = None
self.weights_clip: dict = None
self.need_weight_init = True
self._strength_model = strength_model
self._strength_clip = strength_clip
self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs
@property
def strength_model(self):
@ -110,36 +151,36 @@ class WeightHook(Hook):
def strength_clip(self):
return self._strength_clip * self.strength
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
weights = None
if target == EnumWeightTarget.Model:
strength = self._strength_model
else:
target = target_dict.get('target', None)
if target == EnumWeightTarget.Clip:
strength = self._strength_clip
else:
strength = self._strength_model
if self.need_weight_init:
key_map = {}
if target == EnumWeightTarget.Model:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
else:
if target == EnumWeightTarget.Clip:
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else:
if target == EnumWeightTarget.Model:
weights = self.weights
else:
if target == EnumWeightTarget.Clip:
weights = self.weights_clip
else:
weights = self.weights
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.append(self)
registered.add(self)
return True
# TODO: add logs about any keys that were not applied
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WeightHook = super().clone(subtype)
def clone(self):
c: WeightHook = super().clone()
c.weights = self.weights
c.weights_clip = self.weights_clip
c.need_weight_init = self.need_weight_init
@ -147,127 +188,158 @@ class WeightHook(Hook):
c._strength_clip = self._strength_clip
return c
class PatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.Patch)
self.patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: PatchHook = super().clone(subtype)
c.patches = self.patches
return c
# TODO: add functionality
class ObjectPatchHook(Hook):
def __init__(self):
def __init__(self, object_patches: dict[str]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches: dict = None
self.object_patches = object_patches
self.hook_scope = hook_scope
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: ObjectPatchHook = super().clone(subtype)
def clone(self):
c: ObjectPatchHook = super().clone()
c.object_patches = self.object_patches
return c
# TODO: add functionality
class AddModelsHook(Hook):
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
super().__init__(hook_type=EnumHookType.AddModels)
self.key = key
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.")
class AdditionalModelsHook(Hook):
'''
Hook responsible for telling model management any additional models that should be loaded.
Note, value of hook_scope is ignored and is treated as AllConditioning.
'''
def __init__(self, models: list[ModelPatcher]=None, key: str=None):
super().__init__(hook_type=EnumHookType.AdditionalModels)
self.models = models
self.append_when_same = True
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: AddModelsHook = super().clone(subtype)
c.key = self.key
c.models = self.models.copy() if self.models else self.models
c.append_when_same = self.append_when_same
return c
# TODO: add functionality
class CallbackHook(Hook):
def __init__(self, key: str=None, callback: Callable=None):
super().__init__(hook_type=EnumHookType.Callbacks)
self.key = key
self.callback = callback
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: CallbackHook = super().clone(subtype)
def clone(self):
c: AdditionalModelsHook = super().clone()
c.models = self.models.copy() if self.models else self.models
c.key = self.key
c.callback = self.callback
return c
# TODO: add functionality
class WrapperHook(Hook):
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
super().__init__(hook_type=EnumHookType.Wrappers)
self.wrappers_dict = wrappers_dict
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WrapperHook = super().clone(subtype)
c.wrappers_dict = self.wrappers_dict
return c
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
add_model_options = {"transformer_options": self.wrappers_dict}
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.append(self)
registered.add(self)
return True
class SetInjectionsHook(Hook):
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
super().__init__(hook_type=EnumHookType.SetInjections)
class TransformerOptionsHook(Hook):
'''
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
'''
def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.TransformerOptions)
self.transformers_dict = transformers_dict
self.hook_scope = hook_scope
self._skip_adding = False
'''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.'''
def clone(self):
c: TransformerOptionsHook = super().clone()
c.transformers_dict = self.transformers_dict
c._skip_adding = self._skip_adding
return c
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
if not self.should_register(model, model_options, target_dict, registered):
return False
# NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks
self._skip_adding = False
if self.hook_scope == EnumHookScope.AllConditioning:
add_model_options = {"transformer_options": self.transformers_dict,
"to_load_options": self.transformers_dict}
# skip_adding if included in AllConditioning to avoid double loading
self._skip_adding = True
else:
add_model_options = {"to_load_options": self.transformers_dict}
registered.add(self)
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
return True
def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
if not self._skip_adding:
comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False)
WrapperHook = TransformerOptionsHook
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''
class InjectionsHook(Hook):
def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.Injections)
self.key = key
self.injections = injections
self.hook_scope = hook_scope
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: SetInjectionsHook = super().clone(subtype)
def clone(self):
c: InjectionsHook = super().clone()
c.key = self.key
c.injections = self.injections.copy() if self.injections else self.injections
return c
def add_hook_injections(self, model: 'ModelPatcher'):
# TODO: add functionality
pass
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.")
class HookGroup:
'''
Stores groups of hooks, and allows them to be queried by type.
To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly;
always use the provided functions on HookGroup.
'''
def __init__(self):
self.hooks: list[Hook] = []
self._hook_dict: dict[EnumHookType, list[Hook]] = {}
def __len__(self):
return len(self.hooks)
def add(self, hook: Hook):
if hook not in self.hooks:
self.hooks.append(hook)
self._hook_dict.setdefault(hook.hook_type, []).append(hook)
def remove(self, hook: Hook):
if hook in self.hooks:
self.hooks.remove(hook)
self._hook_dict[hook.hook_type].remove(hook)
def get_type(self, hook_type: EnumHookType):
return self._hook_dict.get(hook_type, [])
def contains(self, hook: Hook):
return hook in self.hooks
def is_subset_of(self, other: HookGroup):
self_hooks = set(self.hooks)
other_hooks = set(other.hooks)
return self_hooks.issubset(other_hooks)
def new_with_common_hooks(self, other: HookGroup):
c = HookGroup()
for hook in self.hooks:
if other.contains(hook):
c.add(hook.clone())
return c
def clone(self):
c = HookGroup()
for hook in self.hooks:
c.add(hook.clone())
return c
def clone_and_combine(self, other: 'HookGroup'):
def clone_and_combine(self, other: HookGroup):
c = self.clone()
if other is not None:
for hook in other.hooks:
c.add(hook.clone())
return c
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
if hook_kf is None:
hook_kf = HookKeyframeGroup()
else:
@ -275,18 +347,11 @@ class HookGroup:
for hook in self.hooks:
hook.hook_keyframe = hook_kf
def get_dict_repr(self):
d: dict[EnumHookType, dict[Hook, None]] = {}
for hook in self.hooks:
with_type = d.setdefault(hook.hook_type, {})
with_type[hook] = None
return d
def get_hooks_for_clip_schedule(self):
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
for hook in self.hooks:
# only care about WeightHooks, for now
if hook.hook_type == EnumHookType.Weight:
for hook in self.get_type(EnumHookType.Weight):
hook: WeightHook
hook_schedule = []
# if no hook keyframes, assign default value
if len(hook.hook_keyframe.keyframes) == 0:
@ -336,7 +401,7 @@ class HookGroup:
hook.reset()
@staticmethod
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
actual: list[HookGroup] = []
for group in hooks_list:
if group is not None:
@ -433,7 +498,7 @@ class HookKeyframeGroup:
c._set_first_as_current()
return c
def initialize_timesteps(self, model: 'BaseModel'):
def initialize_timesteps(self, model: BaseModel):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
@ -522,6 +587,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
sorted_list.extend(object_list)
return sorted_list
def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
# if no hooks or is not a ModelPatcher for sampling, return empty dict
if hooks is None or model.is_clip:
return {}
if transformer_options is None:
transformer_options = {}
for hook in hooks.get_type(EnumHookType.TransformerOptions):
hook: TransformerOptionsHook
hook.on_apply_hooks(model, transformer_options)
return transformer_options
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
@ -548,7 +624,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float
hook.need_weight_init = False
return hook_group
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
if model is None:
return None
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
@ -560,7 +636,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T
return patches_model
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
strength_model: float, strength_clip: float):
key_map = {}
if model is not None:
@ -612,24 +688,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H
else:
c_dict[hooks_key] = cache[hooks_tuple]
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True,
cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
c = []
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
if cache is None:
cache = {}
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
if append_hooks and k == 'hooks':
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
_combine_hooks_from_values(n[1], values, cache)
else:
n[1][k] = values[k]
c.append(n)
return c
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None):
if hooks is None:
return cond
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache)
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
if timestep_range is None:
@ -664,9 +742,10 @@ def combine_with_new_conds(conds: list, new_conds: list):
def set_conds_props(conds: list, strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
final_conds = []
cache = {}
for c in conds:
# first, apply lora_hook to conditioning, if provided
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache)
# next, apply mask to conditioning
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
# apply timesteps, if present
@ -678,9 +757,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str,
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
cache = {}
for c, masked_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache)
# next, apply mask to new conditioning, if provided
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
# apply timesteps, if present
@ -692,9 +772,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.
def set_default_conds_and_combine(conds: list, new_conds: list,
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
cache = {}
for c, new_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache)
# next, add default_cond key to cond so that during sampling, it can be identified
new_c = conditioning_set_values(new_c, {'default': True})
# apply timesteps, if present

View File

@ -382,3 +382,7 @@ class HunyuanVideo(LatentFormat):
]
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
class Cosmos1CV8x8x8(LatentFormat):
latent_channels = 16
latent_dimensions = 3

804
comfy/ldm/cosmos/blocks.py Normal file
View File

@ -0,0 +1,804 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import logging
import numpy as np
import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
) -> torch.Tensor:
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
return t_out
def get_normalization(name: str, channels: int, weight_args={}):
if name == "I":
return nn.Identity()
elif name == "R":
return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
else:
raise ValueError(f"Normalization {name} not found")
class BaseAttentionOp(nn.Module):
def __init__(self):
super().__init__()
class Attention(nn.Module):
"""
Generalized attention impl.
Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided.
If `context_dim` is None, self-attention is assumed.
Parameters:
query_dim (int): Dimension of each query vector.
context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed.
heads (int, optional): Number of attention heads. Defaults to 8.
dim_head (int, optional): Dimension of each head. Defaults to 64.
dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0.
attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default.
qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False.
qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections.
Defaults to "SSI".
qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections.
Defaults to 'per_head'. Only support 'per_head'.
Examples:
>>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1)
>>> query = torch.randn(10, 128) # Batch size of 10
>>> context = torch.randn(10, 256) # Batch size of 10
>>> output = attn(query, context) # Perform the attention operation
Note:
https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
"""
def __init__(
self,
query_dim: int,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
attn_op: Optional[BaseAttentionOp] = None,
qkv_bias: bool = False,
out_bias: bool = False,
qkv_norm: str = "SSI",
qkv_norm_mode: str = "per_head",
backend: str = "transformer_engine",
qkv_format: str = "bshd",
weight_args={},
operations=None,
) -> None:
super().__init__()
self.is_selfattn = context_dim is None # self attention
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
self.heads = heads
self.dim_head = dim_head
self.qkv_norm_mode = qkv_norm_mode
self.qkv_format = qkv_format
if self.qkv_norm_mode == "per_head":
norm_dim = dim_head
else:
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
self.backend = backend
self.to_q = nn.Sequential(
operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[0], norm_dim),
)
self.to_k = nn.Sequential(
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[1], norm_dim),
)
self.to_v = nn.Sequential(
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[2], norm_dim),
)
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, bias=out_bias, **weight_args),
nn.Dropout(dropout),
)
def cal_qkv(
self, x, context=None, mask=None, rope_emb=None, **kwargs
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
del kwargs
"""
self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers.
Before 07/24/2024, these modules normalize across all heads.
After 07/24/2024, to support tensor parallelism and follow the common practice in the community,
we support to normalize per head.
To keep the checkpoint copatibility with the previous code,
we keep the nn.Sequential but call the projection and the normalization layers separately.
We use a flag `self.qkv_norm_mode` to control the normalization behavior.
The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head.
"""
if self.qkv_norm_mode == "per_head":
q = self.to_q[0](x)
context = x if context is None else context
k = self.to_k[0](context)
v = self.to_v[0](context)
q, k, v = map(
lambda t: rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head),
(q, k, v),
)
else:
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
q = self.to_q[1](q)
k = self.to_k[1](k)
v = self.to_v[1](v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
q = apply_rotary_pos_emb(q, rope_emb)
k = apply_rotary_pos_emb(k, rope_emb)
return q, k, v
def cal_attn(self, q, k, v, mask=None):
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
def forward(
self,
x,
context=None,
mask=None,
rope_emb=None,
**kwargs,
):
"""
Args:
x (Tensor): The query tensor of shape [B, Mq, K]
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
return self.cal_attn(q, k, v, mask)
class FeedForward(nn.Module):
"""
Transformer FFN with optional gating
Parameters:
d_model (int): Dimensionality of input features.
d_ff (int): Dimensionality of the hidden layer.
dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1.
activation (callable, optional): The activation function applied after the first linear layer.
Defaults to nn.ReLU().
is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer.
Defaults to False.
bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True.
Example:
>>> ff = FeedForward(d_model=512, d_ff=2048)
>>> x = torch.randn(64, 10, 512) # Example input tensor
>>> output = ff(x)
>>> print(output.shape) # Expected shape: (64, 10, 512)
"""
def __init__(
self,
d_model: int,
d_ff: int,
dropout: float = 0.1,
activation=nn.ReLU(),
is_gated: bool = False,
bias: bool = False,
weight_args={},
operations=None,
) -> None:
super().__init__()
self.layer1 = operations.Linear(d_model, d_ff, bias=bias, **weight_args)
self.layer2 = operations.Linear(d_ff, d_model, bias=bias, **weight_args)
self.dropout = nn.Dropout(dropout)
self.activation = activation
self.is_gated = is_gated
if is_gated:
self.linear_gate = operations.Linear(d_model, d_ff, bias=False, **weight_args)
def forward(self, x: torch.Tensor):
g = self.activation(self.layer1(x))
if self.is_gated:
x = g * self.linear_gate(x)
else:
x = g
assert self.dropout.p == 0.0, "we skip dropout"
return self.layer2(x)
class GPT2FeedForward(FeedForward):
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False, weight_args={}, operations=None):
super().__init__(
d_model=d_model,
d_ff=d_ff,
dropout=dropout,
activation=nn.GELU(),
is_gated=False,
bias=bias,
weight_args=weight_args,
operations=operations,
)
def forward(self, x: torch.Tensor):
assert self.dropout.p == 0.0, "we skip dropout"
x = self.layer1(x)
x = self.activation(x)
x = self.layer2(x)
return x
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class Timesteps(nn.Module):
def __init__(self, num_channels):
super().__init__()
self.num_channels = num_channels
def forward(self, timesteps):
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - 0.0)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
emb = torch.cat([cos_emb, sin_emb], dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, weight_args={}, operations=None):
super().__init__()
logging.debug(
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
)
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, **weight_args)
self.activation = nn.SiLU()
self.use_adaln_lora = use_adaln_lora
if use_adaln_lora:
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, **weight_args)
else:
self.linear_2 = operations.Linear(out_features, out_features, bias=True, **weight_args)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
emb = self.linear_1(sample)
emb = self.activation(emb)
emb = self.linear_2(emb)
if self.use_adaln_lora:
adaln_lora_B_3D = emb
emb_B_D = sample
else:
emb_B_D = emb
adaln_lora_B_3D = None
return emb_B_D, adaln_lora_B_3D
class FourierFeatures(nn.Module):
"""
Implements a layer that generates Fourier features from input tensors, based on randomly sampled
frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems.
[B] -> [B, D]
Parameters:
num_channels (int): The number of Fourier features to generate.
bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1.
normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize
the variance of the features. Defaults to False.
Example:
>>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
>>> x = torch.randn(10, 256) # Example input tensor
>>> output = layer(x)
>>> print(output.shape) # Expected shape: (10, 256)
"""
def __init__(self, num_channels, bandwidth=1, normalize=False):
super().__init__()
self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
self.gain = np.sqrt(2) if normalize else 1
def forward(self, x, gain: float = 1.0):
"""
Apply the Fourier feature transformation to the input tensor.
Args:
x (torch.Tensor): The input tensor.
gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1.
Returns:
torch.Tensor: The transformed tensor, with Fourier features applied.
"""
in_dtype = x.dtype
x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
x = x.cos().mul(self.gain * gain).to(in_dtype)
return x
class PatchEmbed(nn.Module):
"""
PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,
making it suitable for video and image processing tasks. It supports dividing the input into patches
and embedding each patch into a vector of size `out_channels`.
Parameters:
- spatial_patch_size (int): The size of each spatial patch.
- temporal_patch_size (int): The size of each temporal patch.
- in_channels (int): Number of input channels. Default: 3.
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
"""
def __init__(
self,
spatial_patch_size,
temporal_patch_size,
in_channels=3,
out_channels=768,
bias=True,
weight_args={},
operations=None,
):
super().__init__()
self.spatial_patch_size = spatial_patch_size
self.temporal_patch_size = temporal_patch_size
self.proj = nn.Sequential(
Rearrange(
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
r=temporal_patch_size,
m=spatial_patch_size,
n=spatial_patch_size,
),
operations.Linear(
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias, **weight_args
),
)
self.out = nn.Identity()
def forward(self, x):
"""
Forward pass of the PatchEmbed module.
Parameters:
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
B is the batch size,
C is the number of channels,
T is the temporal dimension,
H is the height, and
W is the width of the input.
Returns:
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
"""
assert x.dim() == 5
_, _, T, H, W = x.shape
assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
assert T % self.temporal_patch_size == 0
x = self.proj(x)
return self.out(x)
class FinalLayer(nn.Module):
"""
The final layer of video DiT.
"""
def __init__(
self,
hidden_size,
spatial_patch_size,
temporal_patch_size,
out_channels,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
weight_args={},
operations=None,
):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **weight_args)
self.linear = operations.Linear(
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, **weight_args
)
self.hidden_size = hidden_size
self.n_adaln_chunks = 2
self.use_adaln_lora = use_adaln_lora
if use_adaln_lora:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, adaln_lora_dim, bias=False, **weight_args),
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, **weight_args),
)
else:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, **weight_args)
)
def forward(
self,
x_BT_HW_D,
emb_B_D,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
):
if self.use_adaln_lora:
assert adaln_lora_B_3D is not None
shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk(
2, dim=1
)
else:
shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1)
B = emb_B_D.shape[0]
T = x_BT_HW_D.shape[0] // B
shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T)
x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D)
x_BT_HW_D = self.linear(x_BT_HW_D)
return x_BT_HW_D
class VideoAttn(nn.Module):
"""
Implements video attention with optional cross-attention capabilities.
This module processes video features while maintaining their spatio-temporal structure. It can perform
self-attention within the video features or cross-attention with external context features.
Parameters:
x_dim (int): Dimension of input feature vectors
context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention
num_heads (int): Number of attention heads
bias (bool): Whether to include bias in attention projections. Default: False
qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head"
x_format (str): Format of input tensor. Must be "BTHWD". Default: "BTHWD"
Input shape:
- x: (T, H, W, B, D) video features
- context (optional): (M, B, D) context features for cross-attention
where:
T: temporal dimension
H: height
W: width
B: batch size
D: feature dimension
M: context sequence length
"""
def __init__(
self,
x_dim: int,
context_dim: Optional[int],
num_heads: int,
bias: bool = False,
qkv_norm_mode: str = "per_head",
x_format: str = "BTHWD",
weight_args={},
operations=None,
) -> None:
super().__init__()
self.x_format = x_format
self.attn = Attention(
x_dim,
context_dim,
num_heads,
x_dim // num_heads,
qkv_bias=bias,
qkv_norm="RRI",
out_bias=bias,
qkv_norm_mode=qkv_norm_mode,
qkv_format="sbhd",
weight_args=weight_args,
operations=operations,
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass for video attention.
Args:
x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data.
context (Tensor): Context tensor of shape (B, M, D) or (M, B, D),
where M is the sequence length of the context.
crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms.
rope_emb_L_1_1_D (Optional[Tensor]):
Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
Returns:
Tensor: The output tensor with applied attention, maintaining the input shape.
"""
x_T_H_W_B_D = x
context_M_B_D = context
T, H, W, B, D = x_T_H_W_B_D.shape
x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d")
x_THW_B_D = self.attn(
x_THW_B_D,
context_M_B_D,
crossattn_mask,
rope_emb=rope_emb_L_1_1_D,
)
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
return x_T_H_W_B_D
def adaln_norm_state(norm_state, x, scale, shift):
normalized = norm_state(x)
return normalized * (1 + scale) + shift
class DITBuildingBlock(nn.Module):
"""
A building block for the DiT (Diffusion Transformer) architecture that supports different types of
attention and MLP operations with adaptive layer normalization.
Parameters:
block_type (str): Type of block - one of:
- "cross_attn"/"ca": Cross-attention
- "full_attn"/"fa": Full self-attention
- "mlp"/"ff": MLP/feedforward block
x_dim (int): Dimension of input features
context_dim (Optional[int]): Dimension of context features for cross-attention
num_heads (int): Number of attention heads
mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
bias (bool): Whether to use bias in layers. Default: False
mlp_dropout (float): Dropout rate for MLP. Default: 0.0
qkv_norm_mode (str): QKV normalization mode. Default: "per_head"
x_format (str): Input tensor format. Default: "BTHWD"
use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
"""
def __init__(
self,
block_type: str,
x_dim: int,
context_dim: Optional[int],
num_heads: int,
mlp_ratio: float = 4.0,
bias: bool = False,
mlp_dropout: float = 0.0,
qkv_norm_mode: str = "per_head",
x_format: str = "BTHWD",
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
weight_args={},
operations=None
) -> None:
block_type = block_type.lower()
super().__init__()
self.x_format = x_format
if block_type in ["cross_attn", "ca"]:
self.block = VideoAttn(
x_dim,
context_dim,
num_heads,
bias=bias,
qkv_norm_mode=qkv_norm_mode,
x_format=self.x_format,
weight_args=weight_args,
operations=operations,
)
elif block_type in ["full_attn", "fa"]:
self.block = VideoAttn(
x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format, weight_args=weight_args, operations=operations
)
elif block_type in ["mlp", "ff"]:
self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias, weight_args=weight_args, operations=operations)
else:
raise ValueError(f"Unknown block type: {block_type}")
self.block_type = block_type
self.use_adaln_lora = use_adaln_lora
self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
self.n_adaln_chunks = 3
if use_adaln_lora:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, **weight_args),
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args),
)
else:
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args))
def forward(
self,
x: torch.Tensor,
emb_B_D: torch.Tensor,
crossattn_emb: torch.Tensor,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass for dynamically configured blocks with adaptive normalization.
Args:
x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D).
emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation.
crossattn_emb (Tensor): Tensor for cross-attention blocks.
crossattn_mask (Optional[Tensor]): Optional mask for cross-attention.
rope_emb_L_1_1_D (Optional[Tensor]):
Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
Returns:
Tensor: The output tensor after processing through the configured block and adaptive normalization.
"""
if self.use_adaln_lora:
shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk(
self.n_adaln_chunks, dim=1
)
else:
shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1)
shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = (
shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
)
if self.block_type in ["mlp", "ff"]:
x = x + gate_1_1_1_B_D * self.block(
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
)
elif self.block_type in ["full_attn", "fa"]:
x = x + gate_1_1_1_B_D * self.block(
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=None,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
)
elif self.block_type in ["cross_attn", "ca"]:
x = x + gate_1_1_1_B_D * self.block(
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=crossattn_emb,
crossattn_mask=crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
)
else:
raise ValueError(f"Unknown block type: {self.block_type}")
return x
class GeneralDITTransformerBlock(nn.Module):
"""
A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer.
Each block in the sequence is specified by a block configuration string.
Parameters:
x_dim (int): Dimension of input features
context_dim (int): Dimension of context features for cross-attention blocks
num_heads (int): Number of attention heads
block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention,
full-attention, then MLP)
mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
x_format (str): Input tensor format. Default: "BTHWD"
use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
The block_config string uses "-" to separate block types:
- "ca"/"cross_attn": Cross-attention block
- "fa"/"full_attn": Full self-attention block
- "mlp"/"ff": MLP/feedforward block
Example:
block_config = "ca-fa-mlp" creates a sequence of:
1. Cross-attention block
2. Full self-attention block
3. MLP block
"""
def __init__(
self,
x_dim: int,
context_dim: int,
num_heads: int,
block_config: str,
mlp_ratio: float = 4.0,
x_format: str = "BTHWD",
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
weight_args={},
operations=None
):
super().__init__()
self.blocks = nn.ModuleList()
self.x_format = x_format
for block_type in block_config.split("-"):
self.blocks.append(
DITBuildingBlock(
block_type,
x_dim,
context_dim,
num_heads,
mlp_ratio,
x_format=self.x_format,
use_adaln_lora=use_adaln_lora,
adaln_lora_dim=adaln_lora_dim,
weight_args=weight_args,
operations=operations,
)
)
def forward(
self,
x: torch.Tensor,
emb_B_D: torch.Tensor,
crossattn_emb: torch.Tensor,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x = x + extra_per_block_pos_emb
for block in self.blocks:
x = block(
x,
emb_B_D,
crossattn_emb,
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
)
return x

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,355 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The patcher and unpatcher implementation for 2D and 3D data.
The idea of Haar wavelet is to compute LL, LH, HL, HH component as two 1D convolutions.
One on the rows and one on the columns.
For example, in 1D signal, we have [a, b], then the low-freq compoenent is [a + b] / 2 and high-freq is [a - b] / 2.
We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component.
For H component, we can use a 1D convolution with kernel [1, -1] and stride 2.
Although in principle, we typically only do additional Haar wavelet over the LL component. But here we do it for all
as we need to support downsampling for more than 2x.
For example, 4x downsampling can be done by 2x Haar and additional 2x Haar, and the shape would be.
[3, 256, 256] -> [12, 128, 128] -> [48, 64, 64]
"""
import torch
import torch.nn.functional as F
from einops import rearrange
_WAVELETS = {
"haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
"rearrange": torch.tensor([1.0, 1.0]),
}
_PERSISTENT = False
class Patcher(torch.nn.Module):
"""A module to convert image tensors into patches using torch operations.
The main difference from `class Patching` is that this module implements
all operations using torch, rather than python or numpy, for efficiency purpose.
It's bit-wise identical to the Patching module outputs, with the added
benefit of being torch.jit scriptable.
"""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__()
self.patch_size = patch_size
self.patch_method = patch_method
self.register_buffer(
"wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
)
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
self.register_buffer(
"_arange",
torch.arange(_WAVELETS[patch_method].shape[0]),
persistent=_PERSISTENT,
)
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
if self.patch_method == "haar":
return self._haar(x)
elif self.patch_method == "rearrange":
return self._arrange(x)
else:
raise ValueError("Unknown patch method: " + self.patch_method)
def _dwt(self, x, mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
n = h.shape[0]
g = x.shape[1]
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hh = hh.to(dtype=dtype)
hl = hl.to(dtype=dtype)
x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)
xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))
out = torch.cat([xll, xlh, xhl, xhh], dim=1)
if rescale:
out = out / 2
return out
def _haar(self, x):
for _ in self.range:
x = self._dwt(x, rescale=True)
return x
def _arrange(self, x):
x = rearrange(
x,
"b c (h p1) (w p2) -> b (c p1 p2) h w",
p1=self.patch_size,
p2=self.patch_size,
).contiguous()
return x
class Patcher3D(Patcher):
"""A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos."""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__(patch_method=patch_method, patch_size=patch_size)
self.register_buffer(
"patch_size_buffer",
patch_size * torch.ones([1], dtype=torch.int32),
persistent=_PERSISTENT,
)
def _dwt(self, x, wavelet, mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
n = h.shape[0]
g = x.shape[1]
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hh = hh.to(dtype=dtype)
hl = hl.to(dtype=dtype)
# Handles temporal axis.
x = F.pad(
x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode
).to(dtype)
xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
# Handles spatial axes.
xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
if rescale:
out = out / (2 * torch.sqrt(torch.tensor(2.0)))
return out
def _haar(self, x):
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
for _ in self.range:
x = self._dwt(x, "haar", rescale=True)
return x
def _arrange(self, x):
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
x = rearrange(
x,
"b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w",
p1=self.patch_size,
p2=self.patch_size,
p3=self.patch_size,
).contiguous()
return x
class UnPatcher(torch.nn.Module):
"""A module to convert patches into image tensorsusing torch operations.
The main difference from `class Unpatching` is that this module implements
all operations using torch, rather than python or numpy, for efficiency purpose.
It's bit-wise identical to the Unpatching module outputs, with the added
benefit of being torch.jit scriptable.
"""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__()
self.patch_size = patch_size
self.patch_method = patch_method
self.register_buffer(
"wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
)
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
self.register_buffer(
"_arange",
torch.arange(_WAVELETS[patch_method].shape[0]),
persistent=_PERSISTENT,
)
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
if self.patch_method == "haar":
return self._ihaar(x)
elif self.patch_method == "rearrange":
return self._iarrange(x)
else:
raise ValueError("Unknown patch method: " + self.patch_method)
def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
n = h.shape[0]
g = x.shape[1] // 4
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hh = hh.to(dtype=dtype)
hl = hl.to(dtype=dtype)
xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)
# Inverse transform.
yl = torch.nn.functional.conv_transpose2d(
xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
yl += torch.nn.functional.conv_transpose2d(
xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
yh = torch.nn.functional.conv_transpose2d(
xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
yh += torch.nn.functional.conv_transpose2d(
xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
y = torch.nn.functional.conv_transpose2d(
yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
)
y += torch.nn.functional.conv_transpose2d(
yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
)
if rescale:
y = y * 2
return y
def _ihaar(self, x):
for _ in self.range:
x = self._idwt(x, "haar", rescale=True)
return x
def _iarrange(self, x):
x = rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.patch_size,
p2=self.patch_size,
)
return x
class UnPatcher3D(UnPatcher):
"""A 3D inverse discrete wavelet transform for video wavelet decompositions."""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__(patch_method=patch_method, patch_size=patch_size)
def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors.
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hl = hl.to(dtype=dtype)
hh = hh.to(dtype=dtype)
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
# Height height transposed convolutions.
xll = F.conv_transpose3d(
xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
xll += F.conv_transpose3d(
xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
xlh = F.conv_transpose3d(
xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
xlh += F.conv_transpose3d(
xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
xhl = F.conv_transpose3d(
xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
xhl += F.conv_transpose3d(
xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
xhh = F.conv_transpose3d(
xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
xhh += F.conv_transpose3d(
xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
# Handles width transposed convolutions.
xl = F.conv_transpose3d(
xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
xl += F.conv_transpose3d(
xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
xh = F.conv_transpose3d(
xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
xh += F.conv_transpose3d(
xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
# Handles time axis transposed convolutions.
x = F.conv_transpose3d(
xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)
x += F.conv_transpose3d(
xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)
if rescale:
x = x * (2 * torch.sqrt(torch.tensor(2.0)))
return x
def _ihaar(self, x):
for _ in self.range:
x = self._idwt(x, "haar", rescale=True)
x = x[:, :, self.patch_size - 1 :, ...]
return x
def _iarrange(self, x):
x = rearrange(
x,
"b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)",
p1=self.patch_size,
p2=self.patch_size,
p3=self.patch_size,
)
x = x[:, :, self.patch_size - 1 :, ...]
return x

View File

@ -0,0 +1,120 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for the networks module."""
from typing import Any
import torch
from einops import pack, rearrange, unpack
import comfy.ops
ops = comfy.ops.disable_weight_init
def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
batch_size = x.shape[0]
return rearrange(x, "b c t h w -> (b t) c h w"), batch_size
def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
batch_size, height = x.shape[0], x.shape[-2]
return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height
def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)
def cast_tuple(t: Any, length: int = 1) -> Any:
return t if isinstance(t, tuple) else ((t,) * length)
def replication_pad(x):
return torch.cat([x[:, :, :1, ...], x], dim=2)
def divisible_by(num: int, den: int) -> bool:
return (num % den) == 0
def is_odd(n: int) -> bool:
return not divisible_by(n, 2)
def nonlinearity(x):
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
class CausalNormalize(torch.nn.Module):
def __init__(self, in_channels, num_groups=1):
super().__init__()
self.norm = ops.GroupNorm(
num_groups=num_groups,
num_channels=in_channels,
eps=1e-6,
affine=True,
)
self.num_groups = num_groups
def forward(self, x):
# if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose.
# All new models should use num_groups=1, otherwise causality is not guaranteed.
if self.num_groups == 1:
x, batch_size = time2batch(x)
return batch2time(self.norm(x), batch_size)
return self.norm(x)
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg
return None
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def round_ste(z: torch.Tensor) -> torch.Tensor:
"""Round with straight through gradients."""
zhat = z.round()
return z + (zhat - z).detach()
def log(t, eps=1e-5):
return t.clamp(min=eps).log()
def entropy(prob):
return (-prob * log(prob)).sum(dim=-1)

510
comfy/ldm/cosmos/model.py Normal file
View File

@ -0,0 +1,510 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
"""
from typing import Optional, Tuple
import torch
from einops import rearrange
from torch import nn
from torchvision import transforms
from enum import Enum
import logging
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
from .blocks import (
FinalLayer,
GeneralDITTransformerBlock,
PatchEmbed,
TimestepEmbedding,
Timesteps,
)
from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
class DataType(Enum):
IMAGE = "image"
VIDEO = "video"
class GeneralDIT(nn.Module):
"""
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
Args:
max_img_h (int): Maximum height of the input images.
max_img_w (int): Maximum width of the input images.
max_frames (int): Maximum number of frames in the video sequence.
in_channels (int): Number of input channels (e.g., RGB channels for color images).
out_channels (int): Number of output channels.
patch_spatial (tuple): Spatial resolution of patches for input processing.
patch_temporal (int): Temporal resolution of patches for input processing.
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
block_config (str): Configuration of the transformer block. See Notes for supported block types.
model_channels (int): Base number of channels used throughout the model.
num_blocks (int): Number of transformer blocks.
num_heads (int): Number of heads in the multi-head attention layers.
mlp_ratio (float): Expansion ratio for MLP blocks.
block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD').
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
use_cross_attn_mask (bool): Whether to use mask in cross-attention.
pos_emb_cls (str): Type of positional embeddings.
pos_emb_learnable (bool): Whether positional embeddings are learnable.
pos_emb_interpolation (str): Method for interpolating positional embeddings.
affline_emb_norm (bool): Whether to normalize affine embeddings.
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings.
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
Notes:
Supported block types in block_config:
* cross_attn, ca: Cross attention
* full_attn: Full attention on all flattened tokens
* mlp, ff: Feed forward block
"""
def __init__(
self,
max_img_h: int,
max_img_w: int,
max_frames: int,
in_channels: int,
out_channels: int,
patch_spatial: tuple,
patch_temporal: int,
concat_padding_mask: bool = True,
# attention settings
block_config: str = "FA-CA-MLP",
model_channels: int = 768,
num_blocks: int = 10,
num_heads: int = 16,
mlp_ratio: float = 4.0,
block_x_format: str = "BTHWD",
# cross attention settings
crossattn_emb_channels: int = 1024,
use_cross_attn_mask: bool = False,
# positional embedding settings
pos_emb_cls: str = "sincos",
pos_emb_learnable: bool = False,
pos_emb_interpolation: str = "crop",
affline_emb_norm: bool = False, # whether or not to normalize the affine embedding
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
rope_h_extrapolation_ratio: float = 1.0,
rope_w_extrapolation_ratio: float = 1.0,
rope_t_extrapolation_ratio: float = 1.0,
extra_per_block_abs_pos_emb: bool = False,
extra_per_block_abs_pos_emb_type: str = "sincos",
extra_h_extrapolation_ratio: float = 1.0,
extra_w_extrapolation_ratio: float = 1.0,
extra_t_extrapolation_ratio: float = 1.0,
image_model=None,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
self.max_img_h = max_img_h
self.max_img_w = max_img_w
self.max_frames = max_frames
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_spatial = patch_spatial
self.patch_temporal = patch_temporal
self.num_heads = num_heads
self.num_blocks = num_blocks
self.model_channels = model_channels
self.use_cross_attn_mask = use_cross_attn_mask
self.concat_padding_mask = concat_padding_mask
# positional embedding settings
self.pos_emb_cls = pos_emb_cls
self.pos_emb_learnable = pos_emb_learnable
self.pos_emb_interpolation = pos_emb_interpolation
self.affline_emb_norm = affline_emb_norm
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower()
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
self.dtype = dtype
weight_args = {"device": device, "dtype": dtype}
in_channels = in_channels + 1 if concat_padding_mask else in_channels
self.x_embedder = PatchEmbed(
spatial_patch_size=patch_spatial,
temporal_patch_size=patch_temporal,
in_channels=in_channels,
out_channels=model_channels,
bias=False,
weight_args=weight_args,
operations=operations,
)
self.build_pos_embed(device=device)
self.block_x_format = block_x_format
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
self.t_embedder = nn.ModuleList(
[Timesteps(model_channels),
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),]
)
self.blocks = nn.ModuleDict()
for idx in range(num_blocks):
self.blocks[f"block{idx}"] = GeneralDITTransformerBlock(
x_dim=model_channels,
context_dim=crossattn_emb_channels,
num_heads=num_heads,
block_config=block_config,
mlp_ratio=mlp_ratio,
x_format=self.block_x_format,
use_adaln_lora=use_adaln_lora,
adaln_lora_dim=adaln_lora_dim,
weight_args=weight_args,
operations=operations,
)
if self.affline_emb_norm:
logging.debug("Building affine embedding normalization layer")
self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
else:
self.affline_norm = nn.Identity()
self.final_layer = FinalLayer(
hidden_size=self.model_channels,
spatial_patch_size=self.patch_spatial,
temporal_patch_size=self.patch_temporal,
out_channels=self.out_channels,
use_adaln_lora=self.use_adaln_lora,
adaln_lora_dim=self.adaln_lora_dim,
weight_args=weight_args,
operations=operations,
)
def build_pos_embed(self, device=None):
if self.pos_emb_cls == "rope3d":
cls_type = VideoRopePosition3DEmb
else:
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
kwargs = dict(
model_channels=self.model_channels,
len_h=self.max_img_h // self.patch_spatial,
len_w=self.max_img_w // self.patch_spatial,
len_t=self.max_frames // self.patch_temporal,
is_learnable=self.pos_emb_learnable,
interpolation=self.pos_emb_interpolation,
head_dim=self.model_channels // self.num_heads,
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
device=device,
)
self.pos_embedder = cls_type(
**kwargs,
)
if self.extra_per_block_abs_pos_emb:
assert self.extra_per_block_abs_pos_emb_type in [
"learnable",
], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}"
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
kwargs["device"] = device
self.extra_pos_embedder = LearnablePosEmbAxis(
**kwargs,
)
def prepare_embedded_sequence(
self,
x_B_C_T_H_W: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
Args:
x_B_C_T_H_W (torch.Tensor): video
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
If None, a default value (`self.base_fps`) will be used.
padding_mask (Optional[torch.Tensor]): current it is not used
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
- An optional positional embedding tensor, returned only if the positional embedding class
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
Notes:
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
the `self.pos_embedder` with the shape [T, H, W].
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
`self.pos_embedder` with the fps tensor.
- Otherwise, the positional embeddings are generated without considering fps.
"""
if self.concat_padding_mask:
if padding_mask is not None:
padding_mask = transforms.functional.resize(
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
else:
padding_mask = torch.zeros((x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[-2], x_B_C_T_H_W.shape[-1]), dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
x_B_C_T_H_W = torch.cat(
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
)
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)
else:
extra_pos_emb = None
if "rope" in self.pos_emb_cls.lower():
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
if "fps_aware" in self.pos_emb_cls:
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
else:
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
return x_B_T_H_W_D, None, extra_pos_emb
def decoder_head(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_D: torch.Tensor,
crossattn_emb: torch.Tensor,
origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W]
crossattn_mask: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
del crossattn_emb, crossattn_mask
B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape
x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D")
x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D)
# This is to ensure x_BT_HW_D has the correct shape because
# when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D).
x_BT_HW_D = x_BT_HW_D.view(
B * T_before_patchify // self.patch_temporal,
H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial,
-1,
)
x_B_D_T_H_W = rearrange(
x_BT_HW_D,
"(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
p1=self.patch_spatial,
p2=self.patch_spatial,
H=H_before_patchify // self.patch_spatial,
W=W_before_patchify // self.patch_spatial,
t=self.patch_temporal,
B=B,
)
return x_B_D_T_H_W
def forward_before_blocks(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
crossattn_emb: torch.Tensor,
crossattn_mask: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
image_size: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
scalar_feature: Optional[torch.Tensor] = None,
data_type: Optional[DataType] = DataType.VIDEO,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Args:
x: (B, C, T, H, W) tensor of spatial-temp inputs
timesteps: (B, ) tensor of timesteps
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
crossattn_mask: (B, N) tensor of cross-attention masks
"""
del kwargs
assert isinstance(
data_type, DataType
), f"Expected DataType, got {type(data_type)}. We need discuss this flag later."
original_shape = x.shape
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
x,
fps=fps,
padding_mask=padding_mask,
latent_condition=latent_condition,
latent_condition_sigma=latent_condition_sigma,
)
# logging affline scale information
affline_scale_log_info = {}
timesteps_B_D, adaln_lora_B_3D = self.t_embedder[1](self.t_embedder[0](timesteps.flatten()).to(x.dtype))
affline_emb_B_D = timesteps_B_D
affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach()
if scalar_feature is not None:
raise NotImplementedError("Scalar feature is not implemented yet.")
affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach()
affline_emb_B_D = self.affline_norm(affline_emb_B_D)
if self.use_cross_attn_mask:
if crossattn_mask is not None and not torch.is_floating_point(crossattn_mask):
crossattn_mask = (crossattn_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
crossattn_mask = crossattn_mask[:, None, None, :] # .to(dtype=torch.bool) # [B, 1, 1, length]
else:
crossattn_mask = None
if self.blocks["block0"].x_format == "THWBD":
x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D")
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange(
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D"
)
crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D")
if crossattn_mask:
crossattn_mask = rearrange(crossattn_mask, "B M -> M B")
elif self.blocks["block0"].x_format == "BTHWD":
x = x_B_T_H_W_D
else:
raise ValueError(f"Unknown x_format {self.blocks[0].x_format}")
output = {
"x": x,
"affline_emb_B_D": affline_emb_B_D,
"crossattn_emb": crossattn_emb,
"crossattn_mask": crossattn_mask,
"rope_emb_L_1_1_D": rope_emb_L_1_1_D,
"adaln_lora_B_3D": adaln_lora_B_3D,
"original_shape": original_shape,
"extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
}
return output
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
# crossattn_emb: torch.Tensor,
# crossattn_mask: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
image_size: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
scalar_feature: Optional[torch.Tensor] = None,
data_type: Optional[DataType] = DataType.VIDEO,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Args:
x: (B, C, T, H, W) tensor of spatial-temp inputs
timesteps: (B, ) tensor of timesteps
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
crossattn_mask: (B, N) tensor of cross-attention masks
condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to
augment condition input, the lvg model will condition on the condition_video_augment_sigma value;
we need forward_before_blocks pass to the forward_before_blocks function.
"""
crossattn_emb = context
crossattn_mask = attention_mask
inputs = self.forward_before_blocks(
x=x,
timesteps=timesteps,
crossattn_emb=crossattn_emb,
crossattn_mask=crossattn_mask,
fps=fps,
image_size=image_size,
padding_mask=padding_mask,
scalar_feature=scalar_feature,
data_type=data_type,
latent_condition=latent_condition,
latent_condition_sigma=latent_condition_sigma,
condition_video_augment_sigma=condition_video_augment_sigma,
**kwargs,
)
x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = (
inputs["x"],
inputs["affline_emb_B_D"],
inputs["crossattn_emb"],
inputs["crossattn_mask"],
inputs["rope_emb_L_1_1_D"],
inputs["adaln_lora_B_3D"],
inputs["original_shape"],
)
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
assert (
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
for _, block in self.blocks.items():
assert (
self.blocks["block0"].x_format == block.x_format
), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
x = block(
x,
affline_emb_B_D,
crossattn_emb,
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
x_B_D_T_H_W = self.decoder_head(
x_B_T_H_W_D=x_B_T_H_W_D,
emb_B_D=affline_emb_B_D,
crossattn_emb=None,
origin_shape=original_shape,
crossattn_mask=None,
adaln_lora_B_3D=adaln_lora_B_3D,
)
return x_B_D_T_H_W

View File

@ -0,0 +1,207 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
from einops import rearrange, repeat
from torch import nn
import math
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
"""
Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.
Args:
x (torch.Tensor): The input tensor to normalize.
dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
eps (float, optional): A small constant to ensure numerical stability during division.
Returns:
torch.Tensor: The normalized tensor.
"""
if dim is None:
dim = list(range(1, x.ndim))
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
return x / norm.to(x.dtype)
class VideoPositionEmb(nn.Module):
def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
"""
It delegates the embedding generation to generate_embeddings function.
"""
B_T_H_W_C = x_B_T_H_W_C.shape
embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device)
return embeddings
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None):
raise NotImplementedError
class VideoRopePosition3DEmb(VideoPositionEmb):
def __init__(
self,
*, # enforce keyword arguments
head_dim: int,
len_h: int,
len_w: int,
len_t: int,
base_fps: int = 24,
h_extrapolation_ratio: float = 1.0,
w_extrapolation_ratio: float = 1.0,
t_extrapolation_ratio: float = 1.0,
device=None,
**kwargs, # used for compatibility with other positional embeddings; unused in this class
):
del kwargs
super().__init__()
self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
self.base_fps = base_fps
self.max_h = len_h
self.max_w = len_w
dim = head_dim
dim_h = dim // 6 * 2
dim_w = dim_h
dim_t = dim - 2 * dim_h
assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
self.register_buffer(
"dim_spatial_range",
torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
persistent=False,
)
self.register_buffer(
"dim_temporal_range",
torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
persistent=False,
)
self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))
def generate_embeddings(
self,
B_T_H_W_C: torch.Size,
fps: Optional[torch.Tensor] = None,
h_ntk_factor: Optional[float] = None,
w_ntk_factor: Optional[float] = None,
t_ntk_factor: Optional[float] = None,
device=None,
):
"""
Generate embeddings for the given input size.
Args:
B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.
Returns:
Not specified in the original code snippet.
"""
h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor
h_theta = 10000.0 * h_ntk_factor
w_theta = 10000.0 * w_ntk_factor
t_theta = 10000.0 * t_ntk_factor
h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
B, T, H, W, _ = B_T_H_W_C
uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
assert (
uniform_fps or B == 1 or T == 1
), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
assert (
H <= self.max_h and W <= self.max_w
), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
# apply sequence scaling in temporal dimension
if fps is None: # image case
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
else:
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)
em_T_H_W_D = torch.cat(
[
repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
]
, dim=-2,
)
return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()
class LearnablePosEmbAxis(VideoPositionEmb):
def __init__(
self,
*, # enforce keyword arguments
interpolation: str,
model_channels: int,
len_h: int,
len_w: int,
len_t: int,
device=None,
**kwargs,
):
"""
Args:
interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet.
"""
del kwargs # unused
super().__init__()
self.interpolation = interpolation
assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))
def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
B, T, H, W, _ = B_T_H_W_C
if self.interpolation == "crop":
emb_h_H = self.pos_emb_h[:H].to(device=device)
emb_w_W = self.pos_emb_w[:W].to(device=device)
emb_t_T = self.pos_emb_t[:T].to(device=device)
emb = (
repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
+ repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
+ repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
)
assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
else:
raise ValueError(f"Unknown interpolation method {self.interpolation}")
return normalize(emb, dim=-1, eps=1e-6)

124
comfy/ldm/cosmos/vae.py Normal file
View File

@ -0,0 +1,124 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The causal continuous video tokenizer with VAE or AE formulation for 3D data.."""
import logging
import torch
from torch import nn
from enum import Enum
from .cosmos_tokenizer.layers3d import (
EncoderFactorized,
DecoderFactorized,
CausalConv3d,
)
class IdentityDistribution(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, parameters):
return parameters, (torch.tensor([0.0]), torch.tensor([0.0]))
class GaussianDistribution(torch.nn.Module):
def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
super().__init__()
self.min_logvar = min_logvar
self.max_logvar = max_logvar
def sample(self, mean, logvar):
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
def forward(self, parameters):
mean, logvar = torch.chunk(parameters, 2, dim=1)
logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar)
return self.sample(mean, logvar), (mean, logvar)
class ContinuousFormulation(Enum):
VAE = GaussianDistribution
AE = IdentityDistribution
class CausalContinuousVideoTokenizer(nn.Module):
def __init__(
self, z_channels: int, z_factor: int, latent_channels: int, **kwargs
) -> None:
super().__init__()
self.name = kwargs.get("name", "CausalContinuousVideoTokenizer")
self.latent_channels = latent_channels
self.sigma_data = 0.5
# encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name)
self.encoder = EncoderFactorized(
z_channels=z_factor * z_channels, **kwargs
)
if kwargs.get("temporal_compression", 4) == 4:
kwargs["channels_mult"] = [2, 4]
# decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name)
self.decoder = DecoderFactorized(
z_channels=z_channels, **kwargs
)
self.quant_conv = CausalConv3d(
z_factor * z_channels,
z_factor * latent_channels,
kernel_size=1,
padding=0,
)
self.post_quant_conv = CausalConv3d(
latent_channels, z_channels, kernel_size=1, padding=0
)
# formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
self.distribution = IdentityDistribution() # ContinuousFormulation[formulation_name].value()
num_parameters = sum(param.numel() for param in self.parameters())
logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
logging.info(
f"z_channels={z_channels}, latent_channels={self.latent_channels}."
)
latent_temporal_chunk = 16
self.latent_mean = nn.Parameter(torch.zeros([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
self.latent_std = nn.Parameter(torch.ones([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
z, posteriors = self.distribution(moments)
latent_ch = z.shape[1]
latent_t = z.shape[2]
dtype = z.dtype
mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
return ((z - mean) / std) * self.sigma_data
def decode(self, z):
in_dtype = z.dtype
latent_ch = z.shape[1]
latent_t = z.shape[2]
mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
z = z / self.sigma_data
z = z * std + mean
z = self.post_quant_conv(z)
return self.decoder(z)

View File

@ -89,7 +89,7 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision)
if skip_reshape:
@ -142,6 +142,13 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
if skip_output_reshape:
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
)
else:
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
@ -151,7 +158,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
return out
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision)
if skip_reshape:
@ -215,11 +222,13 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
)
hidden_states = hidden_states.to(dtype)
if skip_output_reshape:
hidden_states = hidden_states.unflatten(0, (-1, heads))
else:
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision)
if skip_reshape:
@ -326,6 +335,12 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
del q, k, v
if skip_output_reshape:
r1 = (
r1.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
)
else:
r1 = (
r1.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
@ -342,7 +357,7 @@ try:
except:
pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
@ -395,6 +410,9 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
if skip_output_reshape:
out = out.permute(0, 2, 1, 3)
else:
out = (
out.reshape(b, -1, heads * dim_head)
)
@ -408,7 +426,7 @@ else:
SDP_BATCH_LIMIT = 2**31
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@ -429,6 +447,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if SDP_BATCH_LIMIT >= b:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
@ -450,7 +469,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
return out
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout="HND"
@ -473,9 +492,13 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
if tensor_layout == "HND":
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
if skip_output_reshape:
out = out.transpose(1, 2)
else:
out = out.reshape(b, -1, heads * dim_head)
return out

View File

@ -33,6 +33,7 @@ import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
import comfy.model_management
import comfy.patcher_extension
@ -856,3 +857,19 @@ class HunyuanVideo(BaseModel):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
return out
class CosmosVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.EDM, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
return out

View File

@ -239,6 +239,50 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["micro_condition"] = False
return dit_config
if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys:
dit_config = {}
dit_config["image_model"] = "cosmos"
dit_config["max_img_h"] = 240
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
dit_config["in_channels"] = 16
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
dit_config["block_config"] = "FA-CA-MLP"
dit_config["concat_padding_mask"] = True
dit_config["pos_emb_cls"] = "rope3d"
dit_config["pos_emb_learnable"] = False
dit_config["pos_emb_interpolation"] = "crop"
dit_config["block_x_format"] = "THWBD"
dit_config["affline_emb_norm"] = True
dit_config["use_adaln_lora"] = True
dit_config["adaln_lora_dim"] = 256
if dit_config["model_channels"] == 4096:
# 7B
dit_config["num_blocks"] = 28
dit_config["num_heads"] = 32
dit_config["extra_per_block_abs_pos_emb"] = True
dit_config["rope_h_extrapolation_ratio"] = 1.0
dit_config["rope_w_extrapolation_ratio"] = 1.0
dit_config["rope_t_extrapolation_ratio"] = 2.0
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
else: # 5120
# 14B
dit_config["num_blocks"] = 36
dit_config["num_heads"] = 40
dit_config["extra_per_block_abs_pos_emb"] = True
dit_config["rope_h_extrapolation_ratio"] = 2.0
dit_config["rope_w_extrapolation_ratio"] = 2.0
dit_config["rope_t_extrapolation_ratio"] = 2.0
dit_config["extra_h_extrapolation_ratio"] = 2.0
dit_config["extra_w_extrapolation_ratio"] = 2.0
dit_config["extra_t_extrapolation_ratio"] = 2.0
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@ -393,6 +437,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
def unet_prefix_from_state_dict(state_dict):
candidates = ["model.diffusion_model.", #ldm/sgm models
"model.model.", #audio models
"net.", #cosmos
]
counts = {k: 0 for k in candidates}
for k in state_dict:

View File

@ -210,7 +210,7 @@ class ModelPatcher:
self.injections: dict[str, list[PatcherInjection]] = {}
self.hook_patches: dict[comfy.hooks._HookRef] = {}
self.hook_patches_backup: dict[comfy.hooks._HookRef] = {}
self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
self.current_hooks: Optional[comfy.hooks.HookGroup] = None
@ -282,7 +282,7 @@ class ModelPatcher:
n.injections[k] = i.copy()
# hooks
n.hook_patches = create_hook_patches_clone(self.hook_patches)
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup)
n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
for group in self.cached_hook_patches:
n.cached_hook_patches[group] = {}
for k in self.cached_hook_patches[group]:
@ -855,6 +855,9 @@ class ModelPatcher:
if key in self.injections:
self.injections.pop(key)
def get_injections(self, key: str):
return self.injections.get(key, None)
def set_additional_models(self, key: str, models: list['ModelPatcher']):
self.additional_models[key] = models
@ -925,9 +928,9 @@ class ModelPatcher:
callback(self, timestep)
def restore_hook_patches(self):
if len(self.hook_patches_backup) > 0:
if self.hook_patches_backup is not None:
self.hook_patches = self.hook_patches_backup
self.hook_patches_backup = {}
self.hook_patches_backup = None
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
self.hook_mode = hook_mode
@ -953,25 +956,26 @@ class ModelPatcher:
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None):
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None):
self.restore_hook_patches()
registered_hooks: list[comfy.hooks.Hook] = []
# handle WrapperHooks, if model_options provided
if model_options is not None:
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
hook.add_hook_patches(self, model_options, target, registered_hooks)
if registered is None:
registered = comfy.hooks.HookGroup()
# handle WeightHooks
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
if hook.hook_ref not in self.hook_patches:
weight_hooks_to_register.append(hook)
else:
registered.add(hook)
if len(weight_hooks_to_register) > 0:
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
for hook in weight_hooks_to_register:
hook.add_hook_patches(self, model_options, target, registered_hooks)
hook.add_hook_patches(self, model_options, target_dict, registered)
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
callback(self, hooks_dict, target)
callback(self, hooks, target_dict, model_options, registered)
return registered
def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
with self.use_ejected():
@ -1022,11 +1026,11 @@ class ModelPatcher:
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
# TODO: return transformer_options dict with any additions from hooks
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
return {}
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
self.patch_hooks(hooks=hooks)
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
callback(self, hooks)
return {}
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected():

View File

@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type):
models += [c[model_type]]
return models
def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]):
def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup):
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
cnets: list[ControlBase] = []
for c in cond:
if 'hooks' in c:
for hook in c['hooks'].hooks:
hook: comfy.hooks.Hook
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
full_hooks.add(hook)
if 'control' in c:
cnets.append(c['control'])
@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co
extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
if extra_hooks is not None:
for hook in extra_hooks.hooks:
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
full_hooks.add(hook)
return hooks_dict
return full_hooks
def convert_cond(cond):
out = []
@ -73,13 +70,11 @@ def get_additional_models(conds, dtype):
cnets: list[ControlBase] = []
gligen = []
add_models = []
hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {}
for k in conds:
cnets += get_models_from_cond(conds[k], "control")
gligen += get_models_from_cond(conds[k], "gligen")
add_models += get_models_from_cond(conds[k], "additional_models")
get_hooks_from_cond(conds[k], hooks)
control_nets = set(cnets)
@ -90,11 +85,20 @@ def get_additional_models(conds, dtype):
inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen]
hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()]
models = control_models + gligen + add_models + hook_models
models = control_models + gligen + add_models
return models, inference_memory
def get_additional_models_from_model_options(model_options: dict[str]=None):
"""loads additional models from registered AddModels hooks"""
models = []
if model_options is not None and "registered_hooks" in model_options:
registered: comfy.hooks.HookGroup = model_options["registered_hooks"]
for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels):
hook: comfy.hooks.AdditionalModelsHook
models.extend(hook.models)
return models
def cleanup_additional_models(models):
"""cleanup additional models that were loaded"""
for m in models:
@ -102,9 +106,10 @@ def cleanup_additional_models(models):
m.cleanup()
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
real_model: 'BaseModel' = None
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
@ -123,12 +128,35 @@ def cleanup_models(conds, models):
cleanup_additional_models(set(control_cleanup))
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
'''
Registers hooks from conds.
'''
# check for hooks in conds - if not registered, see if can be applied
hooks = {}
hooks = comfy.hooks.HookGroup()
for k in conds:
get_hooks_from_cond(conds[k], hooks)
# add wrappers and callbacks from ModelPatcher to transformer_options
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
# register hooks on model/model_options
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
# begin registering hooks
registered = comfy.hooks.HookGroup()
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
# handle all TransformerOptionsHooks
for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions):
hook: comfy.hooks.TransformerOptionsHook
hook.add_hook_patches(model, model_options, target_dict, registered)
# handle all AddModelsHooks
for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels):
hook: comfy.hooks.AdditionalModelsHook
hook.add_hook_patches(model, model_options, target_dict, registered)
# handle all WeightHooks by registering on ModelPatcher
model.register_all_hook_patches(hooks, target_dict, model_options, registered)
# add registered_hooks onto model_options for further reference
if len(registered) > 0:
model_options["registered_hooks"] = registered
# merge original wrappers and callbacks with hooked wrappers and callbacks
to_load_options: dict[str] = model_options.setdefault("to_load_options", {})
for wc_name in ["wrappers", "callbacks"]:
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False)
return to_load_options

View File

@ -810,6 +810,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
for cond in conds_to_modify:
cond['hooks'] = hooks
def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]):
'''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for
HookGroups that have the same reference.'''
registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None)
# if None were registered, make sure all hooks are cleaned from conds
if registered is None:
for k in conds:
for kk in conds[k]:
kk.pop('hooks', None)
return
# find conds that contain hooks to be replaced - group by common HookGroup refs
hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {}
for k in conds:
for kk in conds[k]:
hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
if hooks is not None:
if not hooks.is_subset_of(registered):
to_replace = hook_replacement.setdefault(hooks, [])
to_replace.append(kk)
# for each hook to replace, create a new proper HookGroup and assign to all common conds
for hooks, conds_to_modify in hook_replacement.items():
new_hooks = hooks.new_with_common_hooks(registered)
if len(new_hooks) == 0:
new_hooks = None
for kk in conds_to_modify:
kk['hooks'] = new_hooks
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
hooks_set = set()
@ -819,9 +846,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
return len(hooks_set)
def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
'''
If any patches from hooks, wrappers, or callbacks have .to to be called, call it.
'''
if model_options is None:
return
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
casts = []
if device is not None:
casts.append(device)
if dtype is not None:
casts.append(dtype)
# if nothing to apply, do nothing
if len(casts) == 0:
return
# try to call .to on patches
if "patches" in to_load_options:
patches = to_load_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in to_load_options:
patches = to_load_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
for cast in casts:
patch_list[k] = patch_list[k].to(cast)
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options:
wc: dict[str, list] = to_load_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
if hasattr(wc_list[i], "to"):
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
class CFGGuider:
def __init__(self, model_patcher):
self.model_patcher: 'ModelPatcher' = model_patcher
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
self.model_options = model_patcher.model_options
self.original_conds = {}
self.cfg = 1.0
@ -861,7 +937,7 @@ class CFGGuider:
return self.inner_model.process_latent_out(samples.to(torch.float32))
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
if denoise_mask is not None:
@ -870,6 +946,7 @@ class CFGGuider:
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
try:
self.model_patcher.pre_run()
@ -899,6 +976,7 @@ class CFGGuider:
if get_total_hook_groups_in_conds(self.conds) <= 1:
self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
filter_registered_hooks_on_conds(self.conds, self.model_options)
executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
self.outer_sample,
self,
@ -906,6 +984,7 @@ class CFGGuider:
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
self.model_options = orig_model_options
self.model_patcher.hook_mode = orig_hook_mode
self.model_patcher.restore_hook_patches()

View File

@ -11,6 +11,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import yaml
import math
@ -34,6 +35,7 @@ import comfy.text_encoders.long_clipl
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.model_patcher
import comfy.lora
@ -376,6 +378,19 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.unpatcher3d.wavelets" in sd:
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
self.upscale_index_formula = (8, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 8, 8)
self.downscale_index_formula = (8, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
#TODO: these values are a bit off because this is not a standard VAE
self.memory_used_decode = lambda shape, dtype: (220 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (500 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -641,6 +656,7 @@ class CLIPType(Enum):
LTXV = 8
HUNYUAN_VIDEO = 9
PIXART = 10
COSMOS = 11
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@ -658,6 +674,7 @@ class TEModel(Enum):
T5_XL = 5
T5_BASE = 6
LLAMA3_8 = 7
T5_XXL_OLD = 8
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -672,6 +689,8 @@ def detect_te_model(sd):
return TEModel.T5_XXL
elif weight.shape[-1] == 2048:
return TEModel.T5_XL
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
return TEModel.T5_BASE
if "model.layers.0.post_attention_layernorm.weight" in sd:
@ -681,9 +700,10 @@ def detect_te_model(sd):
def t5xxl_detect(clip_data):
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
weight_name_old = "encoder.block.23.layer.1.DenseReluDense.wi.weight"
for sd in clip_data:
if weight_name in sd:
if weight_name in sd or weight_name_old in sd:
return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
return {}
@ -740,6 +760,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
elif te_model == TEModel.T5_XXL_OLD:
clip_target.clip = comfy.text_encoders.cosmos.te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.cosmos.CosmosT5Tokenizer
elif te_model == TEModel.T5_XL:
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer

View File

@ -14,6 +14,7 @@ import comfy.text_encoders.flux
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
from . import supported_models_base
from . import latent_formats
@ -823,6 +824,37 @@ class HunyuanVideo(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
class Cosmos(supported_models_base.BASE):
unet_config = {
"image_model": "cosmos",
}
sampling_settings = {
"sigma_data": 0.5,
"sigma_max": 80.0,
"sigma_min": 0.002,
}
unet_extra_config = {}
latent_format = latent_formats.Cosmos1CV8x8x8
memory_usage_factor = 2.4 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosVideo(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, Cosmos]
models += [SVD_img2vid]

View File

@ -0,0 +1,42 @@
from comfy import sd1_clip
import comfy.text_encoders.t5
import os
from transformers import T5TokenizerFast
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
if t5xxl_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
class CosmosT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512)
class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
class CosmosTEModel_(CosmosT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return CosmosTEModel_

View File

@ -0,0 +1,22 @@
{
"d_ff": 65536,
"d_kv": 128,
"d_model": 1024,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "relu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": false,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 24,
"num_heads": 128,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 32128
}

View File

@ -0,0 +1,23 @@
import nodes
import torch
import comfy.model_management
class EmptyCosmosLatentVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/video"
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
NODE_CLASS_MAPPINGS = {
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
}

View File

@ -246,7 +246,7 @@ class SetClipHooks:
CATEGORY = "advanced/hooks/clip"
FUNCTION = "apply_hooks"
def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
if hooks is not None:
clip = clip.clone()
if apply_to_conds:
@ -255,7 +255,7 @@ class SetClipHooks:
clip.use_clip_schedule = schedule_clip
if not clip.use_clip_schedule:
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
return (clip,)
class ConditioningTimestepsRange:

View File

@ -912,7 +912,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -922,7 +922,7 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5\ncosmos: old t5 xxl"
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
if type == "stable_cascade":
@ -2225,6 +2225,7 @@ def init_builtin_extra_nodes():
"nodes_lt.py",
"nodes_hooks.py",
"nodes_load_3d.py",
"nodes_cosmos.py",
]
import_failed = []