mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-01-11 02:15:17 +00:00)

commit 97d03ae04a (parent 667c92814e)
StableCascade CLIP model support.

This commit adds a CLIPType enum so comfy.sd.load_clip() can distinguish Stable Diffusion from Stable Cascade checkpoints, introduces StableCascade tokenizer and CLIP-G model classes, makes enable_attention_masks a constructor argument on SDClipModel, wires the new classes into Stable_Cascade_C.clip_target(), and exposes a matching "type" option on the CLIPLoader node.
comfy/sd.py (14 lines changed)

@@ -1,4 +1,5 @@
 import torch
+from enum import Enum

 from comfy import model_management
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
@@ -309,8 +310,11 @@ def load_style_model(ckpt_path):
     model.load_state_dict(model_data)
     return StyleModel(model)

+class CLIPType(Enum):
+    STABLE_DIFFUSION = 1
+    STABLE_CASCADE = 2

-def load_clip(ckpt_paths, embedding_directory=None):
+def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
     clip_data = []
     for p in ckpt_paths:
         clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))

@@ -326,8 +330,12 @@ def load_clip(ckpt_paths, embedding_directory=None):
     clip_target.params = {}
     if len(clip_data) == 1:
         if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
-            clip_target.clip = sdxl_clip.SDXLRefinerClipModel
-            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
+            if clip_type == CLIPType.STABLE_CASCADE:
+                clip_target.clip = sdxl_clip.StableCascadeClipModel
+                clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
+            else:
+                clip_target.clip = sdxl_clip.SDXLRefinerClipModel
+                clip_target.tokenizer = sdxl_clip.SDXLTokenizer
         elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
             clip_target.clip = sd2_clip.SD2ClipModel
             clip_target.tokenizer = sd2_clip.SD2Tokenizer
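With both hunks applied, callers opt into the Stable Cascade text encoder explicitly; a bare load_clip() call behaves exactly as before. A minimal sketch of the new call path (the checkpoint filename is a placeholder, not part of this commit):

    import comfy.sd

    # A 32-layer CLIP-G checkpoint used to be routed unconditionally to the
    # SDXL refiner model; clip_type now disambiguates the two cases.
    clip = comfy.sd.load_clip(
        ckpt_paths=["models/clip/clip_g_example.safetensors"],  # placeholder path
        embedding_directory=None,
        clip_type=comfy.sd.CLIPType.STABLE_CASCADE,
    )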
comfy/sd1_clip.py

@@ -67,7 +67,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True):  # clip-vit-base-patch32
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS

@@ -88,7 +88,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.special_tokens = special_tokens
         self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
-        self.enable_attention_masks = False
+        self.enable_attention_masks = enable_attention_masks

         self.layer_norm_hidden_state = layer_norm_hidden_state
         if layer == "hidden":
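This is a plumbing change only: the flag that was hardcoded to False becomes settable at construction time, which is what lets StableCascadeClipG (added below) turn attention masking on. A minimal sketch, assuming SDClipModel falls back to its built-in SD1 text-model config when textmodel_json_config is None:

    import comfy.sd1_clip

    # Previously this required mutating the attribute after construction:
    #     model.enable_attention_masks = True
    model = comfy.sd1_clip.SDClipModel(device="cpu", enable_attention_masks=True)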
comfy/sdxl_clip.py

@@ -64,3 +64,25 @@ class SDXLClipModel(torch.nn.Module):
 class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None):
         super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
+
+
+class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, tokenizer_path=None, embedding_directory=None):
+        super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
+
+class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None):
+        super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
+
+class StableCascadeClipG(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
+        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
+                         special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
+
+    def load_sd(self, sd):
+        return super().load_sd(sd)
+
+class StableCascadeClipModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None):
+        super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
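The new classes mirror the SDXL refiner pair: a CLIP-G tokenizer (1280-dim embeddings, padded with the end token) wrapped in the single-encoder SD1 containers, with the model reading clip_config_bigg.json, skipping the final layer norm, and enabling attention masks. A rough usage sketch; tokenize_with_weights and encode_token_weights are assumed from the existing SD1Tokenizer/SD1ClipModel API rather than shown in this diff:

    from comfy import sdxl_clip

    tokenizer = sdxl_clip.StableCascadeTokenizer()
    model = sdxl_clip.StableCascadeClipModel(device="cpu")

    # Inherited helpers from the sd1_clip base classes (assumed API).
    tokens = tokenizer.tokenize_with_weights("a photo of a cat")
    cond, pooled = model.encode_token_weights(tokens)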
comfy/supported_models.py

@@ -336,7 +336,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
         return out

     def clip_target(self):
-        return None
+        return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)

 class Stable_Cascade_B(Stable_Cascade_C):
     unet_config = {
nodes.py (9 lines changed)

@@ -854,15 +854,20 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
+                              "type": (["stable_diffusion", "stable_cascade"], ),
                              }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"

     CATEGORY = "advanced/loaders"

-    def load_clip(self, clip_name):
+    def load_clip(self, clip_name, type="stable_diffusion"):
+        clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
+        if type == "stable_cascade":
+            clip_type = comfy.sd.CLIPType.STABLE_CASCADE
+
         clip_path = folder_paths.get_full_path("clip", clip_name)
-        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
         return (clip,)

 class DualCLIPLoader:
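On the node side this surfaces as a new "type" dropdown on the advanced CLIPLoader node; because type defaults to "stable_diffusion", saved workflows keep their old behavior. A sketch of driving the node directly (the filename is a placeholder and must exist in the configured clip folder):

    from nodes import CLIPLoader

    loader = CLIPLoader()
    (clip,) = loader.load_clip(clip_name="clip_g_example.safetensors",  # placeholder
                               type="stable_cascade")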