ComfyUI/comfy/text_encoders/spiece_tokenizer.py

32 lines
1.1 KiB
Python
Raw Normal View History

import torch
2024-07-12 05:08:45 +00:00
class SPieceTokenizer:
    """Thin wrapper around a ``sentencepiece.SentencePieceProcessor``.

    The model may be given as a file path, as the serialized model proto in
    ``bytes``, or as a tensor holding those bytes. The tensor form is what
    :meth:`serialize_model` returns, so a tokenizer can be round-tripped
    through a tensor.
    """

    @staticmethod
    def from_pretrained(path, **kwargs):
        """Build a tokenizer from ``path``; ``kwargs`` are forwarded to ``__init__``."""
        return SPieceTokenizer(path, **kwargs)

    def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
        """Load the SentencePiece model.

        Args:
            tokenizer_path: a model file path, the serialized model proto as
                ``bytes``, or a tensor whose raw bytes are that proto (e.g. the
                output of ``serialize_model``).
            add_bos: forwarded as SentencePieceProcessor's ``add_bos`` option.
            add_eos: forwarded as SentencePieceProcessor's ``add_eos`` option.
        """
        self.add_bos = add_bos
        self.add_eos = add_eos

        # Deferred import: sentencepiece is only needed once a tokenizer is built.
        import sentencepiece

        if torch.is_tensor(tokenizer_path):
            # .cpu() lets this accept tensors on any device; it is a no-op on CPU.
            tokenizer_path = tokenizer_path.cpu().numpy().tobytes()

        options = {"add_bos": self.add_bos, "add_eos": self.add_eos}
        if isinstance(tokenizer_path, bytes):
            self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, **options)
        else:
            self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, **options)

    def get_vocab(self):
        """Return a ``{piece: id}`` dict covering every id in the model.

        If two ids map to the same piece string, the higher id wins.
        """
        sp = self.tokenizer
        return {sp.id_to_piece(i): i for i in range(sp.get_piece_size())}

    def __call__(self, string):
        """Encode ``string`` and return ``{"input_ids": ids}``."""
        return {"input_ids": self.tokenizer.encode(string)}

    def serialize_model(self):
        """Return the serialized model proto as a 1-D ``torch.uint8`` CPU tensor.

        The tensor can be passed back to ``__init__`` to rebuild the tokenizer.
        """
        proto = self.tokenizer.serialized_model_proto()
        if not proto:
            # torch.frombuffer rejects zero-length buffers.
            return torch.empty(0, dtype=torch.uint8)
        # frombuffer avoids creating one Python int per byte (the old list() path).
        # A bytearray gives it a writable buffer, avoiding torch's non-writable warning.
        return torch.frombuffer(bytearray(proto), dtype=torch.uint8)