import argparse
import os
import warnings
from pathlib import Path
from time import perf_counter

import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch

from matcha.cli import plot_spectrogram_to_numpy, process_text
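
# End-to-end ONNX inference for Matcha-TTS: text lines are preprocessed into padded
# token batches, run through the acoustic model to produce mel spectrograms, and
# rendered to audio by an embedded or external vocoder (or saved as .npy mels).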


def validate_args(args):
    assert (
        args.text or args.file
    ), "Either text or file must be provided; Matcha-T(ea)TTS needs some text to whisk the waveforms."
    assert args.temperature >= 0, "Sampling temperature cannot be negative"
    assert args.speaking_rate > 0, "Speaking rate must be greater than 0"
    return args


def write_wavs(model, inputs, output_dir, external_vocoder=None):
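    # Two paths: if the vocoder is embedded in the graph, a single run returns
    # waveforms; otherwise Matcha produces mels and an external vocoder renders audio.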
    if external_vocoder is None:
        print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
        t0 = perf_counter()
        wavs, wav_lengths = model.run(None, inputs)
        infer_secs = perf_counter() - t0
        mel_infer_secs = vocoder_infer_secs = None
    else:
        print("[🍵] Generating mel using Matcha")
        mel_t0 = perf_counter()
        mels, mel_lengths = model.run(None, inputs)
        mel_infer_secs = perf_counter() - mel_t0
        print("Generating waveform from mel using external vocoder")
        vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels}
        vocoder_t0 = perf_counter()
        wavs = external_vocoder.run(None, vocoder_inputs)[0]
        vocoder_infer_secs = perf_counter() - vocoder_t0
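        # Each mel frame expands to 256 waveform samples (the vocoder hop length),
        # so frame counts convert directly to sample counts.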
        wavs = wavs.squeeze(1)
        wav_lengths = mel_lengths * 256
        infer_secs = mel_infer_secs + vocoder_infer_secs

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
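    # Write each waveform, trimmed back to its unpadded length, as 24-bit PCM at 22.05 kHz.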
    for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)):
        output_filename = output_dir.joinpath(f"output_{i + 1}.wav")
        audio = wav[:wav_length]
        print(f"Writing audio to {output_filename}")
        sf.write(output_filename, audio, 22050, "PCM_24")

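    # Real-time factor (RTF): seconds of compute per second of generated audio;
    # values below 1.0 mean faster than real time.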
    wav_secs = wav_lengths.sum() / 22050
    print(f"Inference seconds: {infer_secs}")
    print(f"Generated wav seconds: {wav_secs}")
    rtf = infer_secs / wav_secs
    if mel_infer_secs is not None:
        mel_rtf = mel_infer_secs / wav_secs
        print(f"Matcha RTF: {mel_rtf}")
    if vocoder_infer_secs is not None:
        vocoder_rtf = vocoder_infer_secs / wav_secs
        print(f"Vocoder RTF: {vocoder_rtf}")
    print(f"Overall RTF: {rtf}")


def write_mels(model, inputs, output_dir):
    t0 = perf_counter()
    mels, mel_lengths = model.run(None, inputs)
    infer_secs = perf_counter() - t0

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for i, mel in enumerate(mels):
        output_stem = output_dir.joinpath(f"output_{i + 1}")
        plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png"))
        np.save(output_stem.with_suffix(".npy"), mel)
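    # Audio duration implied by the mels: 256-sample hop at a 22050 Hz sample rate.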
    wav_secs = (mel_lengths * 256).sum() / 22050
    print(f"Inference seconds: {infer_secs}")
    print(f"Generated wav seconds: {wav_secs}")
    rtf = infer_secs / wav_secs
    print(f"RTF: {rtf}")


def main():
    parser = argparse.ArgumentParser(
        description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
    )
    parser.add_argument(
        "model",
        type=str,
        help="ONNX model to use",
    )
    parser.add_argument("--vocoder", type=str, default=None, help="ONNX vocoder to use (default: None)")
    parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
    parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
    parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.667,
        help="Variance of the x0 noise (default: 0.667)",
    )
    parser.add_argument(
        "--speaking-rate",
        type=float,
        default=1.0,
        help="Change the speaking rate; a higher value means slower speech (default: 1.0)",
    )
    parser.add_argument("--gpu", action="store_true", help="Use GPU for inference (default: CPU)")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=os.getcwd(),
        help="Output folder to save results (default: current dir)",
    )

    args = parser.parse_args()
    args = validate_args(args)
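    # Select the ONNX Runtime execution provider; CUDAExecutionProvider requires
    # the onnxruntime-gpu package.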
    if args.gpu:
        providers = ["CUDAExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    model = ort.InferenceSession(args.model, providers=providers)

    model_inputs = model.get_inputs()
    model_outputs = list(model.get_outputs())
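    # Each input line becomes one item in the synthesis batch.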
    if args.text:
        text_lines = args.text.splitlines()
    else:
        with open(args.file, encoding="utf-8") as file:
            text_lines = file.read().splitlines()

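    # Preprocess each line with Matcha's text frontend (on CPU) to get token IDs
    # and their lengths.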
    processed_lines = [process_text(0, line, "cpu") for line in text_lines]
    x = [line["x"].squeeze() for line in processed_lines]
    # Pad the token sequences to the longest in the batch
    x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
    x = x.detach().cpu().numpy()
    x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64)
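    # Assemble the ONNX inputs; "scales" packs [temperature, speaking_rate].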
    inputs = {
        "x": x,
        "x_lengths": x_lengths,
        "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32),
    }
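    # A multi-speaker export takes a fourth input ("spks") besides x, x_lengths, and scales.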
    is_multi_speaker = len(model_inputs) == 4
    if is_multi_speaker:
        if args.spk is None:
            args.spk = 0
            warn = "[!] Speaker ID not provided! Using speaker ID 0"
            warnings.warn(warn, UserWarning)
        inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64)

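    # If the graph's first output is named "wav", the vocoder was embedded at export time.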
    has_vocoder_embedded = model_outputs[0].name == "wav"
    if has_vocoder_embedded:
        write_wavs(model, inputs, args.output_dir)
    elif args.vocoder:
        external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
        write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
    else:
        warn = "[!] No vocoder is embedded in the graph and no external vocoder was provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
        warnings.warn(warn, UserWarning)
        write_mels(model, inputs, args.output_dir)


if __name__ == "__main__":
    main()