import os
from transformers import CLIPTokenizer
import comfy.ops
import torch
import traceback
import zipfile
from . import model_management
import comfy.clip_model
import json
import logging
import numbers
import re


def gen_empty_tokens(special_tokens, length):
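    """Build a padded token list of the given length from the model's special tokens.

    The start/end tokens are optional; the remainder is filled with the pad token.
    """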
    start_token = special_tokens.get("start", None)
    end_token = special_tokens.get("end", None)
    pad_token = special_tokens.get("pad")
    output = []
    if start_token is not None:
        output.append(start_token)
    if end_token is not None:
        output.append(end_token)
    output += [pad_token] * (length - len(output))
    return output


class ClipTokenWeightEncoder:
    def encode_token_weights(self, token_weight_pairs):
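        # Encodes every weighted token section in a single batch. When any weight
        # differs from 1.0, an "empty" prompt is appended so each token embedding can
        # be interpolated between the empty result and the full result:
        #     z = (z - z_empty) * weight + z_empty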
        to_encode = list()
        max_token_len = 0
        has_weights = False
        for x in token_weight_pairs:
            tokens = list(map(lambda a: a[0], x))
            max_token_len = max(len(tokens), max_token_len)
            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
            to_encode.append(tokens)

        sections = len(to_encode)
        if has_weights or sections == 0:
            if hasattr(self, "gen_empty_tokens"):
                to_encode.append(self.gen_empty_tokens(self.special_tokens, max_token_len))
            else:
                to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))

        o = self.encode(to_encode)
        out, pooled = o[:2]

        if pooled is not None:
            first_pooled = pooled[0:1].to(model_management.intermediate_device())
        else:
            first_pooled = pooled

        output = []
        for k in range(0, sections):
            z = out[k:k + 1]
            if has_weights:
                z_empty = out[-1]
                for i in range(len(z)):
                    for j in range(len(z[i])):
                        weight = token_weight_pairs[k][j][1]
                        if weight != 1.0:
                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
            output.append(z)

        if (len(output) == 0):
            r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
        else:
            r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)

        if len(o) > 2:
            extra = {}
            for k in o[2]:
                v = o[2][k]
                if k == "attention_mask":
                    v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
                extra[k] = v

            r = r + (extra,)
        return r


class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
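    """Uses a CLIP-style transformer encoder for text."""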
    LAYERS = [
        "last",
        "pooled",
        "hidden",
        "all"
    ]

    def __init__(self, device="cpu", max_length=77,
                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
                 return_projected_pooled=True, return_attention_masks=False, model_options={}):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS

        if textmodel_json_config is None:
            textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")

        if "model_name" not in model_options:
            model_options = {**model_options, "model_name": "clip_l"}

        if isinstance(textmodel_json_config, dict):
            config = textmodel_json_config
        else:
            with open(textmodel_json_config) as f:
                config = json.load(f)

        te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {})
        for k, v in te_model_options.items():
            config[k] = v

        operations = model_options.get("custom_operations", None)
        scaled_fp8 = None

        if operations is None:
            scaled_fp8 = model_options.get("scaled_fp8", None)
            if scaled_fp8 is not None:
                operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
            else:
                operations = comfy.ops.manual_cast

        self.operations = operations
        self.transformer = model_class(config, dtype, device, self.operations)
        if scaled_fp8 is not None:
            self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))

        self.num_layers = self.transformer.num_layers

        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens

        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.enable_attention_masks = enable_attention_masks
        self.zero_out_masked = zero_out_masked

        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        self.return_attention_masks = return_attention_masks

        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)

    def freeze(self):
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def set_clip_options(self, options):
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
        if self.layer == "all":
            pass
        elif layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def reset_clip_options(self):
        self.layer = self.options_default[0]
        self.layer_idx = self.options_default[1]
        self.return_projected_pooled = self.options_default[2]

    def process_tokens(self, tokens, device):
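        # Converts batches of mixed integer tokens and pre-computed embeddings into
        # input embeddings, building an attention mask and token count per batch.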
        end_token = self.special_tokens.get("end", None)
        if end_token is None:
            cmp_token = self.special_tokens.get("pad", -1)
        else:
            cmp_token = end_token

        embeds_out = []
        attention_masks = []
        num_tokens = []

        for x in tokens:
            attention_mask = []
            tokens_temp = []
            other_embeds = []
            eos = False
            index = 0
            for y in x:
                if isinstance(y, numbers.Integral):
                    if eos:
                        attention_mask.append(0)
                    else:
                        attention_mask.append(1)
                    token = int(y)
                    tokens_temp += [token]
                    if not eos and token == cmp_token:
                        if end_token is None:
                            attention_mask[-1] = 0
                        eos = True
                else:
                    other_embeds.append((index, y))
                index += 1

            tokens_embed = torch.tensor([tokens_temp], device=device, dtype=torch.long)
            tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
            index = 0
            pad_extra = 0
            for o in other_embeds:
                emb = o[1]
                if torch.is_tensor(emb):
                    emb = {"type": "embedding", "data": emb}

                emb_type = emb.get("type", None)
                if emb_type == "embedding":
                    emb = emb.get("data", None)
                else:
                    if hasattr(self.transformer, "preprocess_embed"):
                        emb = self.transformer.preprocess_embed(emb, device=device)
                    else:
                        emb = None

                if emb is None:
                    index += -1
                    continue

                ind = index + o[0]
                emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
                emb_shape = emb.shape[1]
                if emb.shape[-1] == tokens_embed.shape[-1]:
                    tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
                    attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
                    index += emb_shape - 1
                else:
                    index += -1
                    pad_extra += emb_shape
                    logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))

            if pad_extra > 0:
                padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
                tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1)
                attention_mask = attention_mask + [0] * pad_extra

            embeds_out.append(tokens_embed)
            attention_masks.append(attention_mask)
            num_tokens.append(sum(attention_mask))

        return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens

    def forward(self, tokens):
        device = self.transformer.get_input_embeddings().weight.device
        embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)

        attention_mask_model = None
        if self.enable_attention_masks:
            attention_mask_model = attention_mask

        if self.layer == "all":
            intermediate_output = "all"
        else:
            intermediate_output = self.layer_idx

        outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)

        if self.layer == "last":
            z = outputs[0].float()
        else:
            z = outputs[1].float()

        if self.zero_out_masked:
            z *= attention_mask.unsqueeze(-1).float()

        pooled_output = None
        if len(outputs) >= 3:
            if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()

        extra = {}
        if self.return_attention_masks:
            extra["attention_mask"] = attention_mask

        if len(extra) > 0:
            return z, pooled_output, extra

        return z, pooled_output

    def encode(self, tokens):
        return self(tokens)

    def load_sd(self, sd):
        return self.transformer.load_state_dict(sd, strict=False)


def parse_parentheses(string):
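    # Splits a prompt string into top-level chunks, keeping parenthesized groups
    # (including nested ones) intact as single items.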
    result = []
    current_item = ""
    nesting_level = 0
    for char in string:
        if char == "(":
            if nesting_level == 0:
                if current_item:
                    result.append(current_item)
                    current_item = "("
                else:
                    current_item = "("
            else:
                current_item += char
            nesting_level += 1
        elif char == ")":
            nesting_level -= 1
            if nesting_level == 0:
                result.append(current_item + ")")
                current_item = ""
            else:
                current_item += char
        else:
            current_item += char
    if current_item:
        result.append(current_item)
    return result


def token_weights(string, current_weight):
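    # Recursively assigns weights to prompt segments: each nesting level multiplies
    # the weight by 1.1 unless an explicit "(text:weight)" value is given.
    # e.g. token_weights("a (red:1.3) car", 1.0) -> [("a ", 1.0), ("red", 1.3), (" car", 1.0)]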
    a = parse_parentheses(string)
    out = []
    for x in a:
        weight = current_weight
        if len(x) >= 2 and x[-1] == ')' and x[0] == '(':
            x = x[1:-1]
            xx = x.rfind(":")
            weight *= 1.1
            if xx > 0:
                try:
                    weight = float(x[xx + 1:])
                    x = x[:xx]
                except:
                    pass
            out += token_weights(x, weight)
        else:
            out += [(x, current_weight)]
    return out


def escape_important(text):
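    # Temporarily replaces escaped parentheses ("\(" and "\)") with control characters
    # so they are not treated as weighting syntax; unescape_important restores them.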
    text = text.replace("\\)", "\0\1")
    text = text.replace("\\(", "\0\2")
    return text


def unescape_important(text):
    text = text.replace("\0\1", ")")
    text = text.replace("\0\2", "(")
    return text


def safe_load_embed_zip(embed_path):
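    # Fallback loader for embeddings stored as torch zip checkpoints: reads the raw
    # tensor bytes from the archive directly instead of unpickling the file.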
    with zipfile.ZipFile(embed_path) as myzip:
        names = list(filter(lambda a: "data/" in a, myzip.namelist()))
        names.reverse()
        for n in names:
            with myzip.open(n) as myfile:
                data = myfile.read()
                number = len(data) // 4
                length_embed = 1024 #sd2.x
                if number < 768:
                    continue
                if number % 768 == 0:
                    length_embed = 768 #sd1.x
                num_embeds = number // length_embed
                embed = torch.frombuffer(data, dtype=torch.float)
                out = embed.reshape((num_embeds, length_embed)).clone()
                del embed
                return out


def expand_directory_list(directories):
    dirs = set()
    for x in directories:
        dirs.add(x)
        for root, subdir, file in os.walk(x, followlinks=True):
            dirs.add(root)
    return list(dirs)


def bundled_embed(embed, prefix, suffix): #bundled embedding in lora format
    out_list = []
    for k in embed:
        if k.startswith(prefix) and k.endswith(suffix):
            out_list.append(embed[k])

    if len(out_list) == 0:
        return None

    return torch.cat(out_list, dim=0)


def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
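    # Resolves an embedding name to a file inside the allowed embedding directories
    # (trying .safetensors/.pt/.bin extensions) and returns its tensor, or None.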
    if isinstance(embedding_directory, str):
        embedding_directory = [embedding_directory]

    embedding_directory = expand_directory_list(embedding_directory)

    valid_file = None
    for embed_dir in embedding_directory:
        embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
        embed_dir = os.path.abspath(embed_dir)
        try:
            if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
                continue
        except:
            continue
        if not os.path.isfile(embed_path):
            extensions = ['.safetensors', '.pt', '.bin']
            for x in extensions:
                t = embed_path + x
                if os.path.isfile(t):
                    valid_file = t
                    break
        else:
            valid_file = embed_path
        if valid_file is not None:
            break

    if valid_file is None:
        return None

    embed_path = valid_file

    embed_out = None

    try:
        if embed_path.lower().endswith(".safetensors"):
            import safetensors.torch
            embed = safetensors.torch.load_file(embed_path, device="cpu")
        else:
            try:
                embed = torch.load(embed_path, weights_only=True, map_location="cpu")
            except:
                embed_out = safe_load_embed_zip(embed_path)
    except Exception:
        logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
        return None

    if embed_out is None:
        if 'string_to_param' in embed:
            values = embed['string_to_param'].values()
            embed_out = next(iter(values))
        elif isinstance(embed, list):
            out_list = []
            for x in range(len(embed)):
                for k in embed[x]:
                    t = embed[x][k]
                    if t.shape[-1] != embedding_size:
                        continue
                    out_list.append(t.reshape(-1, t.shape[-1]))
            embed_out = torch.cat(out_list, dim=0)
        elif embed_key is not None and embed_key in embed:
            embed_out = embed[embed_key]
        else:
            embed_out = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*')
            if embed_out is None:
                embed_out = bundled_embed(embed, 'bundle_emb.', '.{}'.format(embed_key))
            if embed_out is None:
                values = embed.values()
                embed_out = next(iter(values))
    return embed_out


class SDTokenizer:
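    """Tokenizes prompts with attention weighting and "embedding:name" textual inversion
    support, producing batches of (token, weight) pairs sized for a CLIP-style encoder."""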
    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
        if tokenizer_path is None:
            tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
        self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
        self.min_length = min_length
        self.end_token = None
        self.min_padding = min_padding

        empty = self.tokenizer('')["input_ids"]
        self.tokenizer_adds_end_token = has_end_token
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            if end_token is not None:
                self.end_token = end_token
            else:
                if has_end_token:
                    self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            if end_token is not None:
                self.end_token = end_token
            else:
                self.end_token = empty[0]

        if pad_token is not None:
            self.pad_token = pad_token
        elif pad_with_end:
            self.pad_token = self.end_token
        else:
            self.pad_token = 0

        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length

        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.embedding_directory = embedding_directory
        self.max_word_length = 8
        self.embedding_identifier = "embedding:"
        self.embedding_size = embedding_size
        self.embedding_key = embedding_key
    def _try_get_embedding(self, embedding_name:str):
        '''
        Takes a potential embedding name and tries to retrieve it.
        Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
        '''
        split_embed = embedding_name.split()
        embedding_name = split_embed[0]
        leftover = ' '.join(split_embed[1:])
        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
        if embed is None:
            stripped = embedding_name.strip(',')
            if len(stripped) < len(embedding_name):
                embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
                return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
        return (embed, leftover)

    def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
        '''
        Takes a prompt and converts it to a list of (token, weight, word id) elements.
        Tokens can both be integer tokens and pre computed CLIP tensors.
        Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
        Returned list has the dimensions NxM where M is the input size of CLIP
        '''
        min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
        min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)

        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)

        # tokenize words
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = unescape_important(weighted_segment)
            split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
            to_tokenize = [split[0]]
            for i in range(1, len(split)):
                to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))

            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                # if we find an embedding, deal with the embedding
                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
                    embed, leftover = self._try_get_embedding(embedding_name)
                    if embed is None:
                        logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
                    else:
                        if len(embed.shape) == 1:
                            tokens.append([(embed, weight)])
                        else:
                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
                    #if we accidentally have leftover text, continue parsing using leftover, else move on to next word
                    if leftover != "":
                        word = leftover
                    else:
                        continue
                end = 999999999999
                if self.tokenizer_adds_end_token:
                    end = -1
                #parse word
                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])

        #reshape token array to CLIP input size
        batched_tokens = []
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0, 0))
        batched_tokens.append(batch)
        for i, t_group in enumerate(tokens):
            #determine if we're going to try and keep the tokens in a single batch
            is_large = len(t_group) >= self.max_word_length
            if self.end_token is not None:
                has_end_token = 1
            else:
                has_end_token = 0

            while len(t_group) > 0:
                if len(t_group) + len(batch) > self.max_length - has_end_token:
                    remaining_length = self.max_length - len(batch) - has_end_token
                    #break word in two and add end token
                    if is_large:
                        batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
                        if self.end_token is not None:
                            batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    #add end token and pad
                    else:
                        if self.end_token is not None:
                            batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
                    #start new batch
                    batch = []
                    if self.start_token is not None:
                        batch.append((self.start_token, 1.0, 0))
                    batched_tokens.append(batch)
                else:
                    batch.extend([(t, w, i + 1) for t, w in t_group])
                    t_group = []

        #fill last batch
        if self.end_token is not None:
            batch.append((self.end_token, 1.0, 0))
        if min_padding is not None:
            batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
        if self.pad_to_max_length and len(batch) < self.max_length:
            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if min_length is not None and len(batch) < min_length:
            batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))

        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]

        return batched_tokens

    def untokenize(self, token_weight_pair):
        return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))

    def state_dict(self):
        return {}


class SD1Tokenizer:
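    # Wraps a single SDTokenizer under a configurable attribute name (e.g. "clip_l")
    # and keys the tokenizer output dict by that clip name.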
    def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
        if name is not None:
            self.clip_name = name
            self.clip = "{}".format(self.clip_name)
        else:
            self.clip_name = clip_name
            self.clip = "clip_{}".format(self.clip_name)

        tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
        out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):
        return getattr(self, self.clip).untokenize(token_weight_pair)

    def state_dict(self):
        return getattr(self, self.clip).state_dict()


class SD1CheckpointClipModel(SDClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)


class SD1ClipModel(torch.nn.Module):
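    # Wraps a single SDClipModel instance under a configurable attribute name
    # (e.g. "clip_l") and forwards clip option, encode and load_sd calls to it.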
    def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, name=None, **kwargs):
        super().__init__()

        if name is not None:
            self.clip_name = name
            self.clip = "{}".format(self.clip_name)
        else:
            self.clip_name = clip_name
            self.clip = "clip_{}".format(self.clip_name)

        clip_model = model_options.get("{}_class".format(self.clip), clip_model)
        model_options = {**model_options, "model_name": self.clip}
        setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))

        self.dtypes = set()
        if dtype is not None:
            self.dtypes.add(dtype)

    def set_clip_options(self, options):
        getattr(self, self.clip).set_clip_options(options)

    def reset_clip_options(self):
        getattr(self, self.clip).reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs = token_weight_pairs[self.clip_name]
        out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
        return out

    def load_sd(self, sd):
        return getattr(self, self.clip).load_sd(sd)