From 83dbac28ebaac7c3230bf48b647826621954b39f Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 20 Aug 2024 10:00:16 -0400
Subject: [PATCH] Properly set if clip text pooled projection instead of using
 hack.

---
 comfy/clip_model.py             | 1 -
 comfy/sd1_clip.py               | 7 +++++--
 comfy/sdxl_clip.py              | 4 ++--
 comfy/text_encoders/sd2_clip.py | 2 +-
 4 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index 3c67b737..bdf1e2b1 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -123,7 +123,6 @@ class CLIPTextModel(torch.nn.Module):
         self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
         embed_dim = config_dict["hidden_size"]
         self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
-        self.text_projection.weight.copy_(torch.eye(embed_dim))
         self.dtype = dtype
 
     def get_input_embeddings(self):
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index dc8413b7..676653f7 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -75,7 +75,6 @@ class ClipTokenWeightEncoder:
         return r
 
 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
-    """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = [
         "last",
         "pooled",
@@ -556,8 +555,12 @@ class SD1Tokenizer:
     def state_dict(self):
         return {}
 
+class SD1CheckpointClipModel(SDClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
+
 class SD1ClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SDClipModel, name=None, **kwargs):
+    def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, name=None, **kwargs):
         super().__init__()
 
         if name is not None:
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index 860900cc..a0145caa 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -10,7 +10,7 @@ class SDXLClipG(sd1_clip.SDClipModel):
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
         super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
-                         special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, model_options=model_options)
+                         special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
 
     def load_sd(self, sd):
         return super().load_sd(sd)
@@ -82,7 +82,7 @@ class StableCascadeClipG(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
         super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
-                         special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, model_options=model_options)
+                         special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)
 
     def load_sd(self, sd):
         return super().load_sd(sd)
diff --git a/comfy/text_encoders/sd2_clip.py b/comfy/text_encoders/sd2_clip.py
index 0c98cd85..31fc8986 100644
--- a/comfy/text_encoders/sd2_clip.py
+++ b/comfy/text_encoders/sd2_clip.py
@@ -8,7 +8,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
             layer_idx=-2
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, model_options=model_options)
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=True, model_options=model_options)
 
 class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
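
Note (illustration only, not part of the patch): the line deleted from clip_model.py used to seed text_projection with an identity matrix, so a checkpoint that never loaded projection weights would effectively hand back the raw pooled output; the patch instead has each text-encoder class state explicitly whether it wants the projected pooled vector via return_projected_pooled. Below is a minimal sketch of that selection, assuming the usual CLIP layout where pooled_output is the EOS-token hidden state and text_projection is the final linear layer; pick_pooled is a hypothetical helper, not ComfyUI's actual code path.

import torch

def pick_pooled(pooled_output: torch.Tensor,
                text_projection: torch.nn.Linear,
                return_projected_pooled: bool) -> torch.Tensor:
    # Hypothetical helper illustrating the flag's semantics.
    if return_projected_pooled:
        # SDXLClipG, StableCascadeClipG and SD2ClipHModel pass
        # return_projected_pooled=True: use the projected pooled vector,
        # i.e. pooled_output @ text_projection.weight.T.
        return text_projection(pooled_output)
    # SD1CheckpointClipModel passes return_projected_pooled=False: keep the raw
    # pooled output, so text_projection no longer has to be pre-filled with
    # torch.eye(embed_dim) to fake a no-op projection.
    return pooled_output

# Example: both branches preserve the embedding width.
embed_dim = 1280
proj = torch.nn.Linear(embed_dim, embed_dim, bias=False)
pooled = torch.randn(1, embed_dim)
assert pick_pooled(pooled, proj, True).shape == pick_pooled(pooled, proj, False).shape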