From 8f0009aad0591ceee59a147738aa227187b07898 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 21 Nov 2024 08:38:23 -0500
Subject: [PATCH] Support new flux model variants.

---
 comfy/clip_model.py               | 44 +++++++++++++++++++++++--------
 comfy/clip_vision.py              | 14 ++++++----
 comfy/clip_vision_siglip_384.json | 13 +++++++++
 comfy/ldm/flux/model.py           |  9 ++++---
 comfy/ldm/flux/redux.py           | 25 ++++++++++++++++++
 comfy/lora_convert.py             | 17 ++++++++++++
 comfy/model_base.py               | 32 ++++++++++++++++++++++
 comfy/model_detection.py          |  6 +++++
 comfy/sd.py                       |  6 +++++
 9 files changed, 147 insertions(+), 19 deletions(-)
 create mode 100644 comfy/clip_vision_siglip_384.json
 create mode 100644 comfy/ldm/flux/redux.py
 create mode 100644 comfy/lora_convert.py

diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index 42cdc4f6..23ddea9c 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
 
 ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
                "gelu": torch.nn.functional.gelu,
+               "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
 }
 
 class CLIPMLP(torch.nn.Module):
@@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
 
 
 class CLIPVisionEmbeddings(torch.nn.Module):
-    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
+    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
         super().__init__()
-        self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
+
+        num_patches = (image_size // patch_size) ** 2
+        if model_type == "siglip_vision_model":
+            self.class_embedding = None
+            patch_bias = True
+        else:
+            num_patches = num_patches + 1
+            self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
+            patch_bias = False
 
         self.patch_embedding = operations.Conv2d(
             in_channels=num_channels,
             out_channels=embed_dim,
             kernel_size=patch_size,
             stride=patch_size,
-            bias=False,
+            bias=patch_bias,
             dtype=dtype, device=device
         )
 
-        num_patches = (image_size // patch_size) ** 2
-        num_positions = num_patches + 1
-        self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+        self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
 
     def forward(self, pixel_values):
         embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
-        return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
+        if self.class_embedding is not None:
+            embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
+        return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
 
 
 class CLIPVision(torch.nn.Module):
@@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
         heads = config_dict["num_attention_heads"]
         intermediate_size = config_dict["intermediate_size"]
         intermediate_activation = config_dict["hidden_act"]
+        model_type = config_dict["model_type"]
 
-        self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
-        self.pre_layrnorm = operations.LayerNorm(embed_dim)
+        self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
+        if model_type == "siglip_vision_model":
+            self.pre_layrnorm = lambda a: a
+            self.output_layernorm = True
+        else:
+            self.pre_layrnorm = operations.LayerNorm(embed_dim)
+            self.output_layernorm = False
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.post_layernorm = operations.LayerNorm(embed_dim)
 
@@ -181,14 +196,21 @@ class CLIPVision(torch.nn.Module):
         x = self.pre_layrnorm(x)
         #TODO: attention_mask?
         x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
-        pooled_output = self.post_layernorm(x[:, 0, :])
+        if self.output_layernorm:
+            x = self.post_layernorm(x)
+            pooled_output = x
+        else:
+            pooled_output = self.post_layernorm(x[:, 0, :])
         return x, i, pooled_output
 
 class CLIPVisionModelProjection(torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         self.vision_model = CLIPVision(config_dict, dtype, device, operations)
-        self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
+        if "projection_dim" in config_dict:
+            self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
+        else:
+            self.visual_projection = lambda a: a
 
     def forward(self, *args, **kwargs):
         x = self.vision_model(*args, **kwargs)
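The CLIPVisionEmbeddings change above reduces to a sizing rule: SigLIP checkpoints carry no class token and use a bias on the patch projection, so the position table holds exactly (image_size // patch_size) ** 2 entries, while CLIP keeps one extra slot for the class embedding. A standalone sketch of that rule and of the new activation entry (not part of the patch; the helper name is made up):

import torch

# Position-table sizing under the new constructor logic.
def num_positions(image_size, patch_size, siglip):
    num_patches = (image_size // patch_size) ** 2
    return num_patches if siglip else num_patches + 1

print(num_positions(384, 14, siglip=True))   # 729 for the siglip_vision_model 384px config
print(num_positions(224, 14, siglip=False))  # 257 for the original CLIP ViT-L at 224px

# "gelu_pytorch_tanh" resolves to PyTorch's tanh-approximated GELU.
act = lambda a: torch.nn.functional.gelu(a, approximate="tanh")
print(act(torch.randn(3)))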
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 64392e27..ed917cfb 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -16,9 +16,9 @@ class Output:
     def __setitem__(self, key, item):
         setattr(self, key, item)
 
-def clip_preprocess(image, size=224):
-    mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
-    std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
+def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]):
+    mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
+    std = torch.tensor(std, device=image.device, dtype=image.dtype)
     image = image.movedim(-1, 1)
     if not (image.shape[2] == size and image.shape[3] == size):
         scale = (size / min(image.shape[2], image.shape[3]))
@@ -35,6 +35,8 @@ class ClipVisionModel():
             config = json.load(f)
 
         self.image_size = config.get("image_size", 224)
+        self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
+        self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -51,7 +53,7 @@ class ClipVisionModel():
 
     def encode_image(self, image):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
+        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std).float()
         out = self.model(pixel_values=pixel_values, intermediate_output=-2)
 
         outputs = Output()
@@ -94,7 +96,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
     elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
     elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
-        if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
+        if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
+            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
+        elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
diff --git a/comfy/clip_vision_siglip_384.json b/comfy/clip_vision_siglip_384.json
new file mode 100644
index 00000000..532e03ac
--- /dev/null
+++ b/comfy/clip_vision_siglip_384.json
@@ -0,0 +1,13 @@
+{
+    "num_channels": 3,
+    "hidden_act": "gelu_pytorch_tanh",
+    "hidden_size": 1152,
+    "image_size": 384,
+    "intermediate_size": 4304,
+    "model_type": "siglip_vision_model",
+    "num_attention_heads": 16,
+    "num_hidden_layers": 27,
+    "patch_size": 14,
+    "image_mean": [0.5, 0.5, 0.5],
+    "image_std": [0.5, 0.5, 0.5]
+}
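A minimal sketch of what the new mean/std parameters change in clip_preprocess, assuming an image tensor in [0, 1] with shape [B, H, W, C] (normalization step only; resizing and cropping are left out). With the 0.5/0.5 statistics from clip_vision_siglip_384.json the pixels end up in roughly [-1, 1] instead of the OpenAI CLIP range:

import torch

def normalize(image, mean, std):
    mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
    std = torch.tensor(std, device=image.device, dtype=image.dtype)
    image = image.movedim(-1, 1)  # [B, H, W, C] -> [B, C, H, W]
    return (image - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)

img = torch.rand(1, 384, 384, 3)
siglip_in = normalize(img, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
clip_in = normalize(img, [0.48145466, 0.4578275, 0.40821073],
                    [0.26862954, 0.26130258, 0.27577711])
print(siglip_in.min().item(), siglip_in.max().item())  # close to -1.0 and 1.0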
"clip_vision_config_h.json") elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: - if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: + if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") + elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json") else: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") diff --git a/comfy/clip_vision_siglip_384.json b/comfy/clip_vision_siglip_384.json new file mode 100644 index 00000000..532e03ac --- /dev/null +++ b/comfy/clip_vision_siglip_384.json @@ -0,0 +1,13 @@ +{ + "num_channels": 3, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 384, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5] +} diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index ae1ed109..97ad8ffe 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -20,6 +20,7 @@ import comfy.ldm.common_dit @dataclass class FluxParams: in_channels: int + out_channels: int vec_in_dim: int context_in_dim: int hidden_size: int @@ -29,6 +30,7 @@ class FluxParams: depth_single_blocks: int axes_dim: list theta: int + patch_size: int qkv_bias: bool guidance_embed: bool @@ -43,8 +45,9 @@ class Flux(nn.Module): self.dtype = dtype params = FluxParams(**kwargs) self.params = params - self.in_channels = params.in_channels * 2 * 2 - self.out_channels = self.in_channels + self.patch_size = params.patch_size + self.in_channels = params.in_channels * params.patch_size * params.patch_size + self.out_channels = params.out_channels * params.patch_size * params.patch_size if params.hidden_size % params.num_heads != 0: raise ValueError( f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" @@ -165,7 +168,7 @@ class Flux(nn.Module): def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs): bs, c, h, w = x.shape - patch_size = 2 + patch_size = self.patch_size x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) diff --git a/comfy/ldm/flux/redux.py b/comfy/ldm/flux/redux.py new file mode 100644 index 00000000..527e8316 --- /dev/null +++ b/comfy/ldm/flux/redux.py @@ -0,0 +1,25 @@ +import torch +import comfy.ops + +ops = comfy.ops.manual_cast + +class ReduxImageEncoder(torch.nn.Module): + def __init__( + self, + redux_dim: int = 1152, + txt_in_features: int = 4096, + device=None, + dtype=None, + ) -> None: + super().__init__() + + self.redux_dim = redux_dim + self.device = device + self.dtype = dtype + + self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype) + self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype) + + def forward(self, sigclip_embeds) -> torch.Tensor: + projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds))) + return projected_x diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py new file mode 100644 index 00000000..05032c69 --- /dev/null +++ b/comfy/lora_convert.py @@ -0,0 +1,17 @@ +import torch + 
diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py
new file mode 100644
index 00000000..05032c69
--- /dev/null
+++ b/comfy/lora_convert.py
@@ -0,0 +1,17 @@
+import torch
+
+
+def convert_lora_bfl_control(sd): #BFL loras for Flux
+    sd_out = {}
+    for k in sd:
+        k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
+        sd_out[k_to] = sd[k]
+
+    sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
+    return sd_out
+
+
+def convert_lora(sd):
+    if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
+        return convert_lora_bfl_control(sd)
+    return sd
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 7e92ca10..6f1aa570 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -710,6 +710,38 @@ class Flux(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
 
+    def concat_cond(self, **kwargs):
+        num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
+        out_channels = self.model_config.unet_config["out_channels"]
+
+        if num_channels <= out_channels:
+            return None
+
+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if image is None:
+            image = torch.zeros_like(noise)
+
+        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+        image = self.process_latent_in(image)
+        if num_channels <= out_channels * 2:
+            return image
+
+        #inpaint model
+        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+        if mask is None:
+            mask = torch.ones_like(noise)[:, :1]
+
+        mask = torch.mean(mask, dim=1, keepdim=True)
+        print(mask.shape)
+        mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
+        mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
+        mask = utils.resize_to_batch_size(mask, noise.shape[0])
+        return torch.cat((image, mask), dim=1)
+
     def encode_adm(self, **kwargs):
         return kwargs["pooled_output"]
 
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index b98820d8..008e4b19 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -137,6 +137,12 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config = {}
         dit_config["image_model"] = "flux"
         dit_config["in_channels"] = 16
+        patch_size = 2
+        dit_config["patch_size"] = patch_size
+        in_key = "{}img_in.weight".format(key_prefix)
+        if in_key in state_dict_keys:
+            dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
+        dit_config["out_channels"] = 16
        dit_config["vec_in_dim"] = 768
         dit_config["context_in_dim"] = 4096
         dit_config["hidden_size"] = 3072
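In the inpaint branch of concat_cond, the pixel-space mask is folded into latent resolution by packing each 8x8 block into the channel axis, giving 64 mask channels that are appended after the 16-channel conditioning latent. A standalone sketch of that reshape with made-up sizes:

import torch

B, H, W = 1, 32, 32                    # latent batch/height/width (made up)
mask = torch.ones(B, 1, H * 8, W * 8)  # pixel-space mask after upscaling

# Same view/permute/reshape chain as concat_cond: fold 8x8 blocks into channels.
packed = mask.view(B, H, 8, W, 8).permute(0, 2, 4, 1, 3).reshape(B, -1, H, W)
print(packed.shape)  # torch.Size([1, 64, 32, 32])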
diff --git a/comfy/sd.py b/comfy/sd.py
index 95fc6d27..3dea043f 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -30,9 +30,12 @@
 import comfy.text_encoders.genmo
 import comfy.model_patcher
 import comfy.lora
+import comfy.lora_convert
 import comfy.t2i_adapter.adapter
 import comfy.taesd.taesd
 
+import comfy.ldm.flux.redux
+
 def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     key_map = {}
     if model is not None:
@@ -40,6 +43,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     if clip is not None:
         key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
 
+    lora = comfy.lora_convert.convert_lora(lora)
     loaded = comfy.lora.load_lora(lora, key_map)
     if model is not None:
         new_modelpatcher = model.clone()
@@ -433,6 +437,8 @@ def load_style_model(ckpt_path):
     keys = model_data.keys()
     if "style_embedding" in keys:
         model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
+    elif "redux_down.weight" in keys:
+        model = comfy.ldm.flux.redux.ReduxImageEncoder()
     else:
         raise Exception("invalid style model {}".format(ckpt_path))
     model.load_state_dict(model_data)
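The sd.py hook runs every incoming LoRA through convert_lora before key matching, so BFL-format Flux control LoRAs get remapped onto diffusion_model.* keys. An illustrative run of that conversion; the keys and shapes below are made up and a real state dict has many more entries:

import torch
import comfy.lora_convert

sd = {
    "img_in.lora_A.weight": torch.zeros(128, 64),
    "img_in.lora_B.weight": torch.zeros(3072, 128),
    "single_blocks.0.norm.key_norm.scale": torch.zeros(128),
}
converted = comfy.lora_convert.convert_lora(sd)
print(sorted(converted.keys()))
# ['diffusion_model.img_in.lora_A.weight',
#  'diffusion_model.img_in.lora_B.weight',
#  'diffusion_model.img_in.reshape_weight',
#  'diffusion_model.single_blocks.0.norm.key_norm.scale.set_weight']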