llm-asr-tts/0_Inference_QWen2.5.py

74 lines
2.3 KiB
Python
Raw Permalink Normal View History

2025-03-16 16:41:41 +00:00
from transformers import AutoModelForCausalLM, AutoTokenizer
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav
import torchaudio
import pygame
import time
def play_audio(file_path):
try:
pygame.mixer.init()
pygame.mixer.music.load(file_path)
pygame.mixer.music.play()
while pygame.mixer.music.get_busy():
time.sleep(1) # 等待音频播放结束
print("播放完成!")
except Exception as e:
print(f"播放失败: {e}")
finally:
pygame.mixer.quit()
model_name = r"D:\AI\download\Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
cosyvoice = CosyVoice(r'../pretrained_models/CosyVoice-300M', load_jit=True, load_onnx=False, fp16=True)
# sft usage
print(cosyvoice.list_avaliable_spks())
prompt = "你好,你叫什么名字?"
messages = [
{"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
{"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
generated_ids = model.generate(
**model_inputs,
max_new_tokens=512,
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print("Input:", prompt)
print("Answer:", response)
# ['中文女', '中文男', '日语男', '粤语女', '英文女', '英文男', '韩语女']
for i, j in enumerate(cosyvoice.inference_sft(f'{prompt}', '中文男', stream=False)):
torchaudio.save('prompt_sft_{}.wav'.format(i), j['tts_speech'], 22050)
# play_audio('prompt_sft_{}.wav'.format(i))
# change stream=True for chunk stream inference
for i, j in enumerate(cosyvoice.inference_sft(f'{response}', '中文女', stream=False)):
torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], 22050)
# play_audio('sft_{}.wav'.format(i))
play_audio('prompt_sft_0.wav')
for i, j in enumerate(f'{response}'):
play_audio('sft_{}.wav'.format(i))