Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-07-20 14:37:06 +08:00)
get_modulations added from blepping and minor changes
Commit f04b502ab6 (parent 9f70cfbc42)
comfy/ldm/chroma/model.py
@@ -39,6 +39,16 @@ class ChromaParams:
     n_layers: int
 
 
+class ChromaModulationOut(ModulationOut):
+    @classmethod
+    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
+        return cls(
+            shift=tensor[:, offset : offset + 1, :],
+            scale=tensor[:, offset + 1 : offset + 2, :],
+            gate=tensor[:, offset + 2 : offset + 3, :],
+        )
+
+
 class Chroma(nn.Module):
     """
     Transformer model for flow matching on sequences.
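
As a quick sanity check (not part of the commit; sizes are assumed for illustration), each ModulationOut field produced by from_offset is a [batch, 1, dim] slice taken from three consecutive vector positions:

import torch

# Standalone sketch with dummy sizes. 344 = 38*3 + 2*19*6 + 2 matches
# the defaults in the removed distribute_modulations code; dim=3072 is
# purely illustrative.
tensor = torch.randn(2, 344, 3072)   # [batch_size, vectors, dim]
offset = 3                           # start of some block's (shift, scale, gate) triple
shift = tensor[:, offset : offset + 1, :]
scale = tensor[:, offset + 1 : offset + 2, :]
gate = tensor[:, offset + 2 : offset + 3, :]
print(shift.shape, scale.shape, gate.shape)  # 3 x torch.Size([2, 1, 3072])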
@@ -108,118 +118,34 @@ class Chroma(nn.Module):
         self.skip_mmdit = []
         self.skip_dit = []
         self.lite = False
 
-    @staticmethod
-    def distribute_modulations(tensor: torch.Tensor, single_block_count: int = 38, double_blocks_count: int = 19):
-        """
-        Distributes slices of the tensor into the block_dict as ModulationOut objects.
-
-        Args:
-            tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim].
-        """
-        batch_size, vectors, dim = tensor.shape
-
-        block_dict = {}
-
-        # HARD CODED VALUES! lookup table for the generated vectors
-        # Add 38 single mod blocks
-        for i in range(single_block_count):
-            key = f"single_blocks.{i}.modulation.lin"
-            block_dict[key] = None
-
-        # Add 19 image double blocks
-        for i in range(double_blocks_count):
-            key = f"double_blocks.{i}.img_mod.lin"
-            block_dict[key] = None
-
-        # Add 19 text double blocks
-        for i in range(double_blocks_count):
-            key = f"double_blocks.{i}.txt_mod.lin"
-            block_dict[key] = None
-
-        # Add the final layer
-        block_dict["final_layer.adaLN_modulation.1"] = None
-        # # 6.2b version
-        # block_dict["lite_double_blocks.4.img_mod.lin"] = None
-        # block_dict["lite_double_blocks.4.txt_mod.lin"] = None
-
-        idx = 0  # Index to keep track of the vector slices
-
-        for key in block_dict.keys():
-            if "single_blocks" in key:
-                # Single block: 1 ModulationOut
-                block_dict[key] = ModulationOut(
-                    shift=tensor[:, idx:idx+1, :],
-                    scale=tensor[:, idx+1:idx+2, :],
-                    gate=tensor[:, idx+2:idx+3, :]
-                )
-                idx += 3  # Advance by 3 vectors
-
-            elif "img_mod" in key:
-                # Double block: List of 2 ModulationOut
-                double_block = []
-                for _ in range(2):  # Create 2 ModulationOut objects
-                    double_block.append(
-                        ModulationOut(
-                            shift=tensor[:, idx:idx+1, :],
-                            scale=tensor[:, idx+1:idx+2, :],
-                            gate=tensor[:, idx+2:idx+3, :]
-                        )
-                    )
-                    idx += 3  # Advance by 3 vectors per ModulationOut
-                block_dict[key] = double_block
-
-            elif "txt_mod" in key:
-                # Double block: List of 2 ModulationOut
-                double_block = []
-                for _ in range(2):  # Create 2 ModulationOut objects
-                    double_block.append(
-                        ModulationOut(
-                            shift=tensor[:, idx:idx+1, :],
-                            scale=tensor[:, idx+1:idx+2, :],
-                            gate=tensor[:, idx+2:idx+3, :]
-                        )
-                    )
-                    idx += 3  # Advance by 3 vectors per ModulationOut
-                block_dict[key] = double_block
-
-            elif "final_layer" in key:
-                # Final layer: 1 ModulationOut
-                block_dict[key] = [
-                    tensor[:, idx:idx+1, :],
-                    tensor[:, idx+1:idx+2, :],
-                ]
-                idx += 2  # Advance by 2 vectors
-
-            # elif "lite_double_blocks.4.img_mod" in key:
-            #     # Double block: List of 2 ModulationOut
-            #     double_block = []
-            #     for _ in range(2):  # Create 2 ModulationOut objects
-            #         double_block.append(
-            #             ModulationOut(
-            #                 shift=tensor[:, idx:idx+1, :],
-            #                 scale=tensor[:, idx+1:idx+2, :],
-            #                 gate=tensor[:, idx+2:idx+3, :]
-            #             )
-            #         )
-            #         idx += 3  # Advance by 3 vectors per ModulationOut
-            #     block_dict[key] = double_block
-
-            # elif "lite_double_blocks.4.txt_mod" in key:
-            #     # Double block: List of 2 ModulationOut
-            #     double_block = []
-            #     for _ in range(2):  # Create 2 ModulationOut objects
-            #         double_block.append(
-            #             ModulationOut(
-            #                 shift=tensor[:, idx:idx+1, :],
-            #                 scale=tensor[:, idx+1:idx+2, :],
-            #                 gate=tensor[:, idx+2:idx+3, :]
-            #             )
-            #         )
-            #         idx += 3  # Advance by 3 vectors per ModulationOut
-            #     block_dict[key] = double_block
-
-        return block_dict
+    def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
+        # This function slices up the modulations tensor which has the following layout:
+        #   single     : num_single_blocks * 3 elements
+        #   double_img : num_double_blocks * 6 elements
+        #   double_txt : num_double_blocks * 6 elements
+        #   final      : 2 elements
+        if block_type == "final":
+            return (tensor[:, -2:-1, :], tensor[:, -1:, :])
+        single_block_count = self.params.depth_single_blocks
+        double_block_count = self.params.depth
+        offset = 3 * idx
+        if block_type == "single":
+            return ChromaModulationOut.from_offset(tensor, offset)
+        # Double block modulations are 6 elements so we double 3 * idx.
+        offset *= 2
+        if block_type in {"double_img", "double_txt"}:
+            # Advance past the single block modulations.
+            offset += 3 * single_block_count
+            if block_type == "double_txt":
+                # Advance past the double block img modulations.
+                offset += 6 * double_block_count
+            return (
+                ChromaModulationOut.from_offset(tensor, offset),
+                ChromaModulationOut.from_offset(tensor, offset + 3),
+            )
+        raise ValueError("Bad block_type")
 
     def forward_orig(
         self,
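
Worked through with the defaults from the removed code (38 single blocks, 19 double blocks), the layout puts single modulations at vectors 0-113, double_img at 114-227, double_txt at 228-341, and final at 342-343. A minimal standalone check of the offset arithmetic (illustrative only; it mirrors get_modulations above):

# Standalone sketch of the offset arithmetic in get_modulations(),
# assuming 38 single blocks and 19 double blocks as in the removed
# defaults (so 38*3 + 2*19*6 + 2 = 344 modulation vectors total).
single_block_count = 38
double_block_count = 19

def offset_for(block_type: str, idx: int = 0) -> int:
    offset = 3 * idx
    if block_type == "single":
        return offset
    offset *= 2                         # double blocks use 6 vectors each
    offset += 3 * single_block_count    # skip all single-block vectors
    if block_type == "double_txt":
        offset += 6 * double_block_count  # skip the img halves too
    return offset

assert offset_for("single", 37) == 111      # last single block's triple
assert offset_for("double_img", 0) == 114   # img modulations start here
assert offset_for("double_txt", 0) == 228   # txt modulations start here
assert offset_for("double_txt", 18) == 336  # last txt pair: vectors 336..341
# "final" takes the last two vectors, 342 and 343 (tensor[:, -2:, :]).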
@@ -257,8 +183,6 @@ class Chroma(nn.Module):
 
         mod_vectors = self.distilled_guidance_layer(input_vec)
-
-        mod_vectors_dict = self.distribute_modulations(mod_vectors, 38, 19)
 
         txt = self.txt_in(txt)
 
         ids = torch.cat((txt_ids, img_ids), dim=1)
@@ -267,21 +191,10 @@ class Chroma(nn.Module):
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.double_blocks):
             if i not in self.skip_mmdit:
-                guidance_index = i
-                # if lite we change block 4 guidance with lite guidance
-                # and offset the guidance by 11 blocks after block 4
-                # if self.lite and i == 4:
-                #     img_mod = mod_vectors_dict[f"lite_double_blocks.4.img_mod.lin"]
-                #     txt_mod = mod_vectors_dict[f"lite_double_blocks.4.txt_mod.lin"]
-                # elif self.lite and i > 4:
-                #     guidance_index = i + 11
-                #     img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
-                #     txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
-                # else:
-                img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
-                txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
-                double_mod = [img_mod, txt_mod]
-
+                double_mod = (
+                    self.get_modulations(mod_vectors, "double_img", idx=i),
+                    self.get_modulations(mod_vectors, "double_txt", idx=i),
+                )
                 if ("double_block", i) in blocks_replace:
                     def block_wrap(args):
                         out = {}
@@ -318,7 +231,7 @@ class Chroma(nn.Module):
 
         for i, block in enumerate(self.single_blocks):
             if i not in self.skip_dit:
-                single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
+                single_mod = self.get_modulations(mod_vectors, "single", idx=i)
                 if ("single_block", i) in blocks_replace:
                     def block_wrap(args):
                         out = {}
@@ -345,7 +258,7 @@ class Chroma(nn.Module):
             img[:, txt.shape[1] :, ...] += add
 
         img = img[:, txt.shape[1] :, ...]
-        final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
+        final_mod = self.get_modulations(mod_vectors, "final")
         img = self.final_layer(img, vec=final_mod)  # (N, T, patch_size ** 2 * out_channels)
         return img

comfy/model_base.py
@@ -1049,8 +1049,6 @@ class Hunyuan3Dv2(BaseModel):
         return out
 
 
 class Chroma(BaseModel):
-    chroma_model_mode=False
-
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
@@ -1098,6 +1096,15 @@ class Chroma(BaseModel):
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         # upscale the attention mask, since now we
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            shape = kwargs["noise"].shape
+            mask_ref_size = kwargs["attention_mask_img_shape"]
+            # the model will pad to the patch size, and then divide
+            # essentially dividing and rounding up
+            (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
+            attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
+            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
         guidance = 0.0
         out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor((guidance,)))
         return out
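
The token grid the mask is upscaled to is just the latent size rounded up to whole patches. A small illustration (assumed numbers, not from the commit; patch_size = 2 is typical for Flux-style models):

import math

# Illustrative only: a 1024x1024 image encodes to a 128x128 latent;
# with patch_size = 2 the mask is upscaled to
# ceil(128 / 2) x ceil(128 / 2) = 64 x 64 image tokens.
patch_size = 2
latent_h, latent_w = 128, 128   # noise.shape[2], noise.shape[3]
h_tok = math.ceil(latent_h / patch_size)
w_tok = math.ceil(latent_w / patch_size)
print(h_tok, w_tok)  # 64 64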

comfy/supported_models.py
@@ -1025,14 +1025,22 @@ class Chroma(supported_models_base.BASE):
         "multiplier": 1.0,
         "shift": 1.0,
     }
 
     latent_format = comfy.latent_formats.Flux
-    memory_usage_factor = 2.8
+    memory_usage_factor = 1.8
 
     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Chroma(self, model_type=model_base.ModelType.FLUX, device=device)
         return out
 
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.chroma.ChromaTokenizer, comfy.text_encoders.chroma.chroma_te(**t5_detect))
 
 models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Chroma]
 
 models += [SVD_img2vid]