# ComfyUI/comfy/text_encoders/ace_text_cleaners.py

# basic text cleaners for the ACE step model
# I didn't copy the ones from the reference code because I didn't want to deal with the dependencies
# TODO: more languages than english?
import re
def japanese_to_romaji(japanese_text):
    """
    Convert Japanese hiragana and katakana to romaji (Latin alphabet representation).

    Characters that have no entry in the kana table (kanji, Latin letters,
    digits, ...) are copied to the output unchanged.

    Args:
        japanese_text (str): Text containing hiragana and/or katakana characters

    Returns:
        str: The romaji (Latin alphabet) equivalent
    """
    # Dictionary mapping kana characters to their romaji equivalents
    kana_map = {
        # Katakana characters
        'ア': 'a', 'イ': 'i', 'ウ': 'u', 'エ': 'e', 'オ': 'o',
        'カ': 'ka', 'キ': 'ki', 'ク': 'ku', 'ケ': 'ke', 'コ': 'ko',
        'サ': 'sa', 'シ': 'shi', 'ス': 'su', 'セ': 'se', 'ソ': 'so',
        'タ': 'ta', 'チ': 'chi', 'ツ': 'tsu', 'テ': 'te', 'ト': 'to',
        'ナ': 'na', 'ニ': 'ni', 'ヌ': 'nu', 'ネ': 'ne', 'ノ': 'no',
        'ハ': 'ha', 'ヒ': 'hi', 'フ': 'fu', 'ヘ': 'he', 'ホ': 'ho',
        'マ': 'ma', 'ミ': 'mi', 'ム': 'mu', 'メ': 'me', 'モ': 'mo',
        'ヤ': 'ya', 'ユ': 'yu', 'ヨ': 'yo',
        'ラ': 'ra', 'リ': 'ri', 'ル': 'ru', 'レ': 're', 'ロ': 'ro',
        'ワ': 'wa', 'ヲ': 'wo', 'ン': 'n',
        # Katakana voiced consonants
        'ガ': 'ga', 'ギ': 'gi', 'グ': 'gu', 'ゲ': 'ge', 'ゴ': 'go',
        'ザ': 'za', 'ジ': 'ji', 'ズ': 'zu', 'ゼ': 'ze', 'ゾ': 'zo',
        'ダ': 'da', 'ヂ': 'ji', 'ヅ': 'zu', 'デ': 'de', 'ド': 'do',
        'バ': 'ba', 'ビ': 'bi', 'ブ': 'bu', 'ベ': 'be', 'ボ': 'bo',
        'パ': 'pa', 'ピ': 'pi', 'プ': 'pu', 'ペ': 'pe', 'ポ': 'po',
        # Katakana combinations
        'キャ': 'kya', 'キュ': 'kyu', 'キョ': 'kyo',
        'シャ': 'sha', 'シュ': 'shu', 'ショ': 'sho',
        'チャ': 'cha', 'チュ': 'chu', 'チョ': 'cho',
        'ニャ': 'nya', 'ニュ': 'nyu', 'ニョ': 'nyo',
        'ヒャ': 'hya', 'ヒュ': 'hyu', 'ヒョ': 'hyo',
        'ミャ': 'mya', 'ミュ': 'myu', 'ミョ': 'myo',
        'リャ': 'rya', 'リュ': 'ryu', 'リョ': 'ryo',
        'ギャ': 'gya', 'ギュ': 'gyu', 'ギョ': 'gyo',
        'ジャ': 'ja', 'ジュ': 'ju', 'ジョ': 'jo',
        'ビャ': 'bya', 'ビュ': 'byu', 'ビョ': 'byo',
        'ピャ': 'pya', 'ピュ': 'pyu', 'ピョ': 'pyo',
        # Katakana small characters and special cases
        'ッ': '',  # Small tsu (doubles the following consonant)
        'ャ': 'ya', 'ュ': 'yu', 'ョ': 'yo',
        # Katakana extras
        'ヴ': 'vu', 'ファ': 'fa', 'フィ': 'fi', 'フェ': 'fe', 'フォ': 'fo',
        'ウィ': 'wi', 'ウェ': 'we', 'ウォ': 'wo',
        # Hiragana characters
        'あ': 'a', 'い': 'i', 'う': 'u', 'え': 'e', 'お': 'o',
        'か': 'ka', 'き': 'ki', 'く': 'ku', 'け': 'ke', 'こ': 'ko',
        'さ': 'sa', 'し': 'shi', 'す': 'su', 'せ': 'se', 'そ': 'so',
        'た': 'ta', 'ち': 'chi', 'つ': 'tsu', 'て': 'te', 'と': 'to',
        'な': 'na', 'に': 'ni', 'ぬ': 'nu', 'ね': 'ne', 'の': 'no',
        'は': 'ha', 'ひ': 'hi', 'ふ': 'fu', 'へ': 'he', 'ほ': 'ho',
        'ま': 'ma', 'み': 'mi', 'む': 'mu', 'め': 'me', 'も': 'mo',
        'や': 'ya', 'ゆ': 'yu', 'よ': 'yo',
        'ら': 'ra', 'り': 'ri', 'る': 'ru', 'れ': 're', 'ろ': 'ro',
        'わ': 'wa', 'を': 'wo', 'ん': 'n',
        # Hiragana voiced consonants
        'が': 'ga', 'ぎ': 'gi', 'ぐ': 'gu', 'げ': 'ge', 'ご': 'go',
        'ざ': 'za', 'じ': 'ji', 'ず': 'zu', 'ぜ': 'ze', 'ぞ': 'zo',
        'だ': 'da', 'ぢ': 'ji', 'づ': 'zu', 'で': 'de', 'ど': 'do',
        'ば': 'ba', 'び': 'bi', 'ぶ': 'bu', 'べ': 'be', 'ぼ': 'bo',
        'ぱ': 'pa', 'ぴ': 'pi', 'ぷ': 'pu', 'ぺ': 'pe', 'ぽ': 'po',
        # Hiragana combinations
        'きゃ': 'kya', 'きゅ': 'kyu', 'きょ': 'kyo',
        'しゃ': 'sha', 'しゅ': 'shu', 'しょ': 'sho',
        'ちゃ': 'cha', 'ちゅ': 'chu', 'ちょ': 'cho',
        'にゃ': 'nya', 'にゅ': 'nyu', 'にょ': 'nyo',
        'ひゃ': 'hya', 'ひゅ': 'hyu', 'ひょ': 'hyo',
        'みゃ': 'mya', 'みゅ': 'myu', 'みょ': 'myo',
        'りゃ': 'rya', 'りゅ': 'ryu', 'りょ': 'ryo',
        'ぎゃ': 'gya', 'ぎゅ': 'gyu', 'ぎょ': 'gyo',
        'じゃ': 'ja', 'じゅ': 'ju', 'じょ': 'jo',
        'びゃ': 'bya', 'びゅ': 'byu', 'びょ': 'byo',
        'ぴゃ': 'pya', 'ぴゅ': 'pyu', 'ぴょ': 'pyo',
        # Hiragana small characters and special cases
        'っ': '',  # Small tsu (doubles the following consonant)
        'ゃ': 'ya', 'ゅ': 'yu', 'ょ': 'yo',
        # Common punctuation and spaces
        '\u3000': ' ',  # Japanese (ideographic) space
        '、': ', ', '。': '. ',
    }
    small_tsu = ('ッ', 'っ')

    result = []
    length = len(japanese_text)
    i = 0
    while i < length:
        char = japanese_text[i]

        # Small tsu (sokuon): double the consonant of the following kana.
        # A trailing small tsu falls through and maps to the empty string.
        if char in small_tsu and i < length - 1:
            next_romaji = kana_map.get(japanese_text[i + 1])
            # Only double real consonants; vowels, "n" and punctuation are skipped.
            if next_romaji and next_romaji[0].isalpha() and next_romaji[0] not in 'aiueon':
                result.append(next_romaji[0])
            i += 1
            continue

        # Two-character combinations (kya, sha, fa, wi, ...). Every
        # two-character key in the table is such a combination, so a plain
        # lookup also covers the small-vowel extras like "ファ".
        combo = japanese_text[i:i + 2]
        if len(combo) == 2 and combo in kana_map:
            result.append(kana_map[combo])
            i += 2
            continue

        # Regular character; anything unmapped (kanji, romaji, ...) is kept as is.
        result.append(kana_map.get(char, char))
        i += 1
    return ''.join(result)
# Number words indexed by value; index 0 is empty because "zero" is never
# spoken inside a compound number.
_ONES = ("", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
         "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
         "seventeen", "eighteen", "nineteen")
_TENS = ("", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety")
# Scales checked from largest to smallest.
_SCALES = ((1000000000, "billion"), (1000000, "million"), (1000, "thousand"), (100, "hundred"))
_DIGITS = ("zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine")
# Cardinal words whose ordinal is not formed by appending "th".
_IRREGULAR_ORDINALS = {
    "one": "first", "two": "second", "three": "third", "five": "fifth",
    "eight": "eighth", "nine": "ninth", "twelve": "twelfth",
}


def number_to_text(num, ordinal=False):
    """
    Convert a number (int or float) to its English text representation.

    Args:
        num: The number to convert
        ordinal (bool): If True and num is an integer, return the ordinal
            form ("first", "twenty second"). Ignored for floats.

    Returns:
        str: Text representation of the number, or the string
        "Input must be a number" when num is neither int nor float.

    Raises:
        OverflowError, ValueError: for float infinity / NaN (from int()).
    """
    if not isinstance(num, (int, float)):
        return "Input must be a number"

    is_float = isinstance(num, float)

    # Handle special case of zero
    if num == 0:
        return "zeroth" if ordinal and not is_float else "zero"

    # Handle negative numbers
    negative = num < 0
    num = abs(num)

    if is_float:
        # Local import keeps this block self-contained.
        from decimal import Decimal

        # _int_to_text(0) is "", so values below one need an explicit "zero".
        int_text = _int_to_text(int(num)) or "zero"
        # str()/repr() switch to exponent notation for very small or very large
        # floats; Decimal's fixed-point format always yields plain digits.
        plain = format(Decimal(repr(num)), "f")
        decimal_str = plain.partition(".")[2] or "0"
        result = int_text + " point " + " ".join(_digit_to_text(int(digit)) for digit in decimal_str)
    else:
        result = _int_to_text(num)
        if ordinal:
            result = _cardinal_to_ordinal(result)

    # Add 'negative' prefix for negative numbers
    if negative:
        result = "negative " + result
    return result


def _cardinal_to_ordinal(text):
    """Turn cardinal number text into ordinal text by rewriting its last word."""
    head, sep, last = text.rpartition(" ")
    if last in _IRREGULAR_ORDINALS:
        last = _IRREGULAR_ORDINALS[last]
    elif last.endswith("y"):
        last = last[:-1] + "ieth"  # twenty -> twentieth
    else:
        last += "th"
    return head + sep + last


def _int_to_text(num):
    """Helper function to convert a non-negative integer to text ("" for 0)."""
    if num < 20:
        return _ONES[num]
    if num < 100:
        tens, unit = divmod(num, 10)
        return _TENS[tens] + (" " + _ONES[unit] if unit else "")
    for scale, name in _SCALES:
        if num >= scale:
            count, rest = divmod(num, scale)
            text = _int_to_text(count) + " " + name
            return text + (" " + _int_to_text(rest) if rest else "")
    return ""  # unreachable: every num >= 100 matches the "hundred" scale


def _digit_to_text(digit):
    """Convert a single digit (0-9) to text"""
    return _DIGITS[digit]
# Matches any run of one or more whitespace characters.
_whitespace_re = re.compile(r"\s+")
# Per-language list of (regular expression, replacement) pairs for
# abbreviations. Each pattern matches the abbreviation at a word boundary,
# case-insensitively, including its trailing period.
_abbreviations = {
    "en": [
        (re.compile(r"\b%s\." % abbreviation, re.IGNORECASE), expansion)
        for abbreviation, expansion in [
            ("mrs", "misess"),
            ("mr", "mister"),
            ("dr", "doctor"),
            ("st", "saint"),
            ("co", "company"),
            ("jr", "junior"),
            ("maj", "major"),
            ("gen", "general"),
            ("drs", "doctors"),
            ("rev", "reverend"),
            ("lt", "lieutenant"),
            ("hon", "honorable"),
            ("sgt", "sergeant"),
            ("capt", "captain"),
            ("esq", "esquire"),
            ("ltd", "limited"),
            ("col", "colonel"),
            ("ft", "fort"),
        ]
    ],
}


def expand_abbreviations_multilingual(text, lang="en"):
    """
    Replace known abbreviations (e.g. "dr.") with their spoken form.

    Args:
        text (str): Text to process
        lang (str): Language code selecting the abbreviation table

    Returns:
        str: Text with abbreviations expanded; returned unchanged when no
        table exists for lang.
    """
    # Languages without a table are a no-op rather than a KeyError.
    for regex, replacement in _abbreviations.get(lang, ()):
        text = regex.sub(replacement, text)
    return text
# Per-language list of (regular expression, replacement) pairs that spell out
# symbols. Replacements are padded with spaces so they never fuse with
# neighbouring words.
_symbols_multilingual = {
    "en": [
        (re.compile(r"%s" % re.escape(symbol), re.IGNORECASE), spoken)
        for symbol, spoken in [
            ("&", " and "),
            ("@", " at "),
            ("%", " percent "),
            ("#", " hash "),
            ("$", " dollar "),
            ("£", " pound "),
            ("°", " degree "),
        ]
    ],
}


def expand_symbols_multilingual(text, lang="en"):
    """
    Replace symbols such as "&" or "%" with their spoken form.

    Args:
        text (str): Text to process
        lang (str): Language code selecting the symbol table

    Returns:
        str: Text with symbols spelled out, runs of spaces collapsed to one
        space and leading/trailing whitespace stripped. When no table exists
        for lang only the space cleanup is applied.
    """
    # Languages without a table are a no-op rather than a KeyError.
    for regex, replacement in _symbols_multilingual.get(lang, ()):
        text = regex.sub(replacement, text)
    # The padded replacements leave double spaces behind; collapse them.
    text = re.sub(r" {2,}", " ", text)
    return text.strip()
# Per-language pattern for ordinals: group 1 is the digits, group 2 the suffix.
_ordinal_re = {
    "en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
}
# Any run of ASCII digits.
_number_re = re.compile(r"[0-9]+")
# Amounts with the currency symbol directly before or after the digits.
_currency_re = {
    "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
    "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
    "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
}
# Numbers using "," as thousands separator and "." as decimal mark (1,234.5).
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
# Numbers using "." as thousands separator and "," as decimal mark (1.234,5).
# The separator dot is escaped so it cannot match arbitrary characters.
_dot_number_re = re.compile(r"\b\d{1,3}(\.\d{3})*(\,\d+)?\b")
# A decimal number with either "." or "," as the decimal mark.
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
def _remove_commas(m):
    """Regex substitution callback: return the matched text without commas."""
    matched = m.group(0)
    return matched.replace(",", "")
def _remove_dots(m):
    """Regex substitution callback: return the matched text without dots."""
    matched = m.group(0)
    return matched.replace(".", "")
def _expand_decimal_point(m, lang="en"):
    """
    Regex substitution callback: spell out a decimal number.

    Group 1 of the match holds the number; a decimal comma is normalized to
    a dot before conversion. lang is accepted but not used here.
    """
    normalized = m.group(1).replace(",", ".")
    value = float(normalized)
    return number_to_text(value)
def _expand_currency(m, lang="en", currency="USD"):
    """
    Regex substitution callback: spell out the numeric part of a currency amount.

    Args:
        m: Match whose full text is the amount including its currency symbol
        lang (str): Language code; accepted for API symmetry, not used
        currency (str): Currency code; accepted for API symmetry, not used

    Returns:
        str: The amount in words. Whole amounts are read as integers
        ("five"), not as "five point zero".

    Raises:
        ValueError: if the matched text does not reduce to a valid number
            (e.g. more than one decimal separator).
    """
    # Treat "," as the decimal mark, then drop everything but digits and ".".
    digits = re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))
    amount = float(digits)
    if amount.is_integer():
        # Convert as int so no fractional part is spoken for whole amounts.
        return number_to_text(int(amount))
    return number_to_text(amount)
def _expand_ordinal(m, lang="en"):
    """
    Regex substitution callback for ordinals such as "22nd".

    Group 1 of the match holds the digits; they are passed to number_to_text
    with ordinal=True. lang is accepted but not used here.
    """
    value = int(m.group(1))
    return number_to_text(value, ordinal=True)
def _expand_number(m, lang="en"):
    """
    Regex substitution callback: spell out a plain run of digits.

    The whole match is parsed as an int. lang is accepted but not used here.
    """
    value = int(m.group(0))
    return number_to_text(value)
def expand_numbers_multilingual(text, lang="en"):
    """
    Spell out the numbers found in text.

    Steps, in order: strip thousands separators, expand currency amounts,
    decimals, ordinals (only for languages with an ordinal pattern) and
    finally any remaining digit runs.

    Args:
        text (str): Text to process
        lang (str): Language code; "en" and "ru" use "," as the thousands
            separator, every other language uses "."

    Returns:
        str: Text with numbers replaced by words
    """
    if lang in ("en", "ru"):
        text = _comma_number_re.sub(_remove_commas, text)
    else:
        text = _dot_number_re.sub(_remove_dots, text)

    # Currency expansion is best effort: a malformed amount for one currency
    # must not prevent the others (or the rest of the pipeline) from running.
    for code in ("GBP", "USD", "EUR"):
        try:
            text = _currency_re[code].sub(
                lambda m, code=code: _expand_currency(m, lang, code), text
            )
        except (ValueError, KeyError, IndexError, OverflowError):
            continue

    text = _decimal_number_re.sub(lambda m: _expand_decimal_point(m, lang), text)

    # Languages without an ordinal pattern skip this step instead of failing.
    ordinal_re = _ordinal_re.get(lang)
    if ordinal_re is not None:
        text = ordinal_re.sub(lambda m: _expand_ordinal(m, lang), text)

    text = _number_re.sub(lambda m: _expand_number(m, lang), text)
    return text
def lowercase(text):
    """Return text converted to lowercase."""
    lowered = text.lower()
    return lowered
def collapse_whitespace(text):
    """Replace every match of the module-level whitespace pattern with one space."""
    return _whitespace_re.sub(" ", text)
def multilingual_cleaners(text, lang):
    """
    Full cleaning pipeline: strip double quotes, lowercase, expand numbers,
    abbreviations and symbols, then collapse whitespace.

    Args:
        text (str): Text to clean
        lang (str): Language code passed to the expansion steps

    Returns:
        str: The cleaned text
    """
    text = text.replace('"', "")
    if lang == "tr":
        # Map these capitals by hand before the generic lowercase() call.
        text = text.replace("İ", "i")
        text = text.replace("Ö", "ö")
        text = text.replace("Ü", "ü")
    text = lowercase(text)

    expansion_steps = (
        expand_numbers_multilingual,
        expand_abbreviations_multilingual,
        expand_symbols_multilingual,
    )
    for step in expansion_steps:
        # Each step is best effort: on failure the text from the previous
        # step is kept. Catch Exception (not a bare except) so that
        # KeyboardInterrupt / SystemExit still propagate.
        try:
            text = step(text, lang)
        except Exception:
            continue

    text = collapse_whitespace(text)
    return text
def basic_cleaners(text):
    """Basic pipeline that lowercases and collapses whitespace without transliteration."""
    return collapse_whitespace(lowercase(text))