get_mdulations added from blepping and minor changes

This commit is contained in:
silveroxides 2025-03-25 21:38:11 +01:00
parent 9f70cfbc42
commit f04b502ab6
3 changed files with 60 additions and 132 deletions

View File

@ -39,6 +39,16 @@ class ChromaParams:
n_layers: int n_layers: int
class ChromaModulationOut(ModulationOut):
@classmethod
def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
return cls(
shift=tensor[:, offset : offset + 1, :],
scale=tensor[:, offset + 1 : offset + 2, :],
gate=tensor[:, offset + 2 : offset + 3, :],
)
class Chroma(nn.Module): class Chroma(nn.Module):
""" """
Transformer model for flow matching on sequences. Transformer model for flow matching on sequences.
@ -108,118 +118,34 @@ class Chroma(nn.Module):
self.skip_mmdit = [] self.skip_mmdit = []
self.skip_dit = [] self.skip_dit = []
self.lite = False self.lite = False
@staticmethod
def distribute_modulations(tensor: torch.Tensor, single_block_count: int = 38, double_blocks_count: int = 19):
"""
Distributes slices of the tensor into the block_dict as ModulationOut objects.
Args: def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim]. # This function slices up the modulations tensor which has the following layout:
""" # single : num_single_blocks * 3 elements
batch_size, vectors, dim = tensor.shape # double_img : num_double_blocks * 6 elements
# double_txt : num_double_blocks * 6 elements
block_dict = {} # final : 2 elements
if block_type == "final":
# HARD CODED VALUES! lookup table for the generated vectors return (tensor[:, -2:-1, :], tensor[:, -1:, :])
# Add 38 single mod blocks single_block_count = self.params.depth_single_blocks
for i in range(single_block_count): double_block_count = self.params.depth
key = f"single_blocks.{i}.modulation.lin" offset = 3 * idx
block_dict[key] = None if block_type == "single":
return ChromaModulationOut.from_offset(tensor, offset)
# Add 19 image double blocks # Double block modulations are 6 elements so we double 3 * idx.
for i in range(double_blocks_count): offset *= 2
key = f"double_blocks.{i}.img_mod.lin" if block_type in {"double_img", "double_txt"}:
block_dict[key] = None # Advance past the single block modulations.
offset += 3 * single_block_count
# Add 19 text double blocks if block_type == "double_txt":
for i in range(double_blocks_count): # Advance past the double block img modulations.
key = f"double_blocks.{i}.txt_mod.lin" offset += 6 * double_block_count
block_dict[key] = None return (
ChromaModulationOut.from_offset(tensor, offset),
# Add the final layer ChromaModulationOut.from_offset(tensor, offset + 3),
block_dict["final_layer.adaLN_modulation.1"] = None
# # 6.2b version
# block_dict["lite_double_blocks.4.img_mod.lin"] = None
# block_dict["lite_double_blocks.4.txt_mod.lin"] = None
idx = 0 # Index to keep track of the vector slices
for key in block_dict.keys():
if "single_blocks" in key:
# Single block: 1 ModulationOut
block_dict[key] = ModulationOut(
shift=tensor[:, idx:idx+1, :],
scale=tensor[:, idx+1:idx+2, :],
gate=tensor[:, idx+2:idx+3, :]
) )
idx += 3 # Advance by 3 vectors raise ValueError("Bad block_type")
elif "img_mod" in key:
# Double block: List of 2 ModulationOut
double_block = []
for _ in range(2): # Create 2 ModulationOut objects
double_block.append(
ModulationOut(
shift=tensor[:, idx:idx+1, :],
scale=tensor[:, idx+1:idx+2, :],
gate=tensor[:, idx+2:idx+3, :]
)
)
idx += 3 # Advance by 3 vectors per ModulationOut
block_dict[key] = double_block
elif "txt_mod" in key:
# Double block: List of 2 ModulationOut
double_block = []
for _ in range(2): # Create 2 ModulationOut objects
double_block.append(
ModulationOut(
shift=tensor[:, idx:idx+1, :],
scale=tensor[:, idx+1:idx+2, :],
gate=tensor[:, idx+2:idx+3, :]
)
)
idx += 3 # Advance by 3 vectors per ModulationOut
block_dict[key] = double_block
elif "final_layer" in key:
# Final layer: 1 ModulationOut
block_dict[key] = [
tensor[:, idx:idx+1, :],
tensor[:, idx+1:idx+2, :],
]
idx += 2 # Advance by 2 vectors
# elif "lite_double_blocks.4.img_mod" in key:
# # Double block: List of 2 ModulationOut
# double_block = []
# for _ in range(2): # Create 2 ModulationOut objects
# double_block.append(
# ModulationOut(
# shift=tensor[:, idx:idx+1, :],
# scale=tensor[:, idx+1:idx+2, :],
# gate=tensor[:, idx+2:idx+3, :]
# )
# )
# idx += 3 # Advance by 3 vectors per ModulationOut
# block_dict[key] = double_block
# elif "lite_double_blocks.4.txt_mod" in key:
# # Double block: List of 2 ModulationOut
# double_block = []
# for _ in range(2): # Create 2 ModulationOut objects
# double_block.append(
# ModulationOut(
# shift=tensor[:, idx:idx+1, :],
# scale=tensor[:, idx+1:idx+2, :],
# gate=tensor[:, idx+2:idx+3, :]
# )
# )
# idx += 3 # Advance by 3 vectors per ModulationOut
# block_dict[key] = double_block
return block_dict
def forward_orig( def forward_orig(
self, self,
@ -257,8 +183,6 @@ class Chroma(nn.Module):
mod_vectors = self.distilled_guidance_layer(input_vec) mod_vectors = self.distilled_guidance_layer(input_vec)
mod_vectors_dict = self.distribute_modulations(mod_vectors, 38, 19)
txt = self.txt_in(txt) txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1) ids = torch.cat((txt_ids, img_ids), dim=1)
@ -267,21 +191,10 @@ class Chroma(nn.Module):
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks): for i, block in enumerate(self.double_blocks):
if i not in self.skip_mmdit: if i not in self.skip_mmdit:
guidance_index = i double_mod = (
# if lite we change block 4 guidance with lite guidance self.get_modulations(mod_vectors, "double_img", idx=i),
# and offset the guidance by 11 blocks after block 4 self.get_modulations(mod_vectors, "double_txt", idx=i),
# if self.lite and i == 4: )
# img_mod = mod_vectors_dict[f"lite_double_blocks.4.img_mod.lin"]
# txt_mod = mod_vectors_dict[f"lite_double_blocks.4.txt_mod.lin"]
# elif self.lite and i > 4:
# guidance_index = i + 11
# img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
# txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
# else:
img_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.img_mod.lin"]
txt_mod = mod_vectors_dict[f"double_blocks.{guidance_index}.txt_mod.lin"]
double_mod = [img_mod, txt_mod]
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
@ -318,7 +231,7 @@ class Chroma(nn.Module):
for i, block in enumerate(self.single_blocks): for i, block in enumerate(self.single_blocks):
if i not in self.skip_dit: if i not in self.skip_dit:
single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] single_mod = self.get_modulations(mod_vectors, "single", idx=i)
if ("single_block", i) in blocks_replace: if ("single_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
@ -345,7 +258,7 @@ class Chroma(nn.Module):
img[:, txt.shape[1] :, ...] += add img[:, txt.shape[1] :, ...] += add
img = img[:, txt.shape[1] :, ...] img = img[:, txt.shape[1] :, ...]
final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] final_mod = self.get_modulations(mod_vectors, "final")
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
return img return img

View File

@ -1049,8 +1049,6 @@ class Hunyuan3Dv2(BaseModel):
return out return out
class Chroma(BaseModel): class Chroma(BaseModel):
chroma_model_mode=False
def __init__(self, model_config, model_type=ModelType.FLUX, device=None): def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
@ -1098,6 +1096,15 @@ class Chroma(BaseModel):
if cross_attn is not None: if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
# upscale the attention mask, since now we # upscale the attention mask, since now we
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
shape = kwargs["noise"].shape
mask_ref_size = kwargs["attention_mask_img_shape"]
# the model will pad to the patch size, and then divide
# essentially dividing and rounding up
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
guidance = 0.0 guidance = 0.0
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor((guidance,))) out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor((guidance,)))
return out return out

View File

@ -1025,14 +1025,22 @@ class Chroma(supported_models_base.BASE):
"multiplier": 1.0, "multiplier": 1.0,
"shift": 1.0, "shift": 1.0,
} }
latent_format = comfy.latent_formats.Flux latent_format = comfy.latent_formats.Flux
memory_usage_factor = 2.8
memory_usage_factor = 1.8
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.Chroma(self, model_type=model_base.ModelType.FLUX, device=device) out = model_base.Chroma(self, model_type=model_base.ModelType.FLUX, device=device)
return out return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.chroma.ChromaTokenizer, comfy.text_encoders.chroma.chroma_te(**t5_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Chroma] models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Chroma]
models += [SVD_img2vid] models += [SVD_img2vid]