llm-asr-tts/3.07backend_service.py

from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
import time
import os
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM
from funasr import AutoModel
import edge_tts
import asyncio
import langid
import tempfile
app = Flask(__name__)
CORS(app)
# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('app.log'),  # log to file
        logging.StreamHandler()          # log to console
    ]
)

# Configuration
AUDIO_RATE = 16000
OUTPUT_DIR = "./output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Initialize models
app.logger.info("Loading models...")
model_dir = "D:/AI/download/SenseVoiceSmall"
model_senceVoice = AutoModel(model=model_dir, trust_remote_code=True)

# Load the Qwen2.5 large language model
model_name = "D:/AI/download/Qwen2.5-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
app.logger.info("All models loaded!")

# Map detected language codes to edge-tts voices
language_speaker = {
    "ja": "ja-JP-NanamiNeural",   # Japanese
    "fr": "fr-FR-DeniseNeural",   # French
    "es": "ca-ES-JoanaNeural",    # Spanish (note: a Catalan voice in the original config)
    "de": "de-DE-KatjaNeural",    # German
    "zh": "zh-CN-XiaoyiNeural",   # Chinese
    "en": "en-US-AnaNeural",      # English
}

# ---------------------- API routes ----------------------
@app.route('/asr', methods=['POST'])
def handle_asr():
    """Handle speech recognition (ASR) requests."""
    if 'audio' not in request.files:
        return jsonify({"error": "No audio file provided"}), 400
    try:
        audio_file = request.files['audio']
        app.logger.info(f"Received audio file: {audio_file.filename}")
        # Create a temporary WAV path, then close the handle before writing:
        # on Windows an open NamedTemporaryFile cannot be reopened by name.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp_path = tmp.name
        audio_file.save(tmp_path)
        try:
            # Run SenseVoice speech recognition
            res = model_senceVoice.generate(
                input=tmp_path,
                cache={},
                language="auto",
                use_itn=False,
            )
            # SenseVoice prefixes the transcript with tags such as
            # <|zh|><|NEUTRAL|>...; keep only the text after the last ">"
            asr_text = res[0]['text'].split(">")[-1]
        finally:
            os.remove(tmp_path)  # clean up the temporary file
        app.logger.info(f"ASR result: {asr_text}")
        return jsonify({"asr_text": asr_text})
    except Exception as e:
        app.logger.error(f"ASR error: {str(e)}")
        return jsonify({"error": str(e)}), 500
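
# Illustrative client call for /asr (not part of the service). The base URL
# and the file name `test.wav` are assumptions for this sketch.
#
#   import requests
#   with open("test.wav", "rb") as f:
#       resp = requests.post("http://localhost:5000/asr", files={"audio": f})
#   print(resp.json())  # e.g. {"asr_text": "..."}
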
@app.route('/generate_text', methods=['POST'])
def handle_generate_text():
    """Handle LLM text generation requests."""
    try:
        data = request.get_json()
        asr_text = data.get('asr_text', '')
        app.logger.info(f"Received ASR text: {asr_text}")
        if not asr_text:
            return jsonify({"error": "No ASR text provided"}), 400
        # Build the chat template; the system prompt (kept in Chinese) gives the
        # assistant the persona of a lively, playful 18-year-old female college
        # student named Qianwen
        messages = [
            {"role": "system", "content": "你叫千问是一个18岁的女大学生性格活泼开朗说话俏皮"},
            {"role": "user", "content": asr_text},
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # Generate a reply and strip the prompt tokens from the output
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(**model_inputs, max_new_tokens=512)
        generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids
                         in zip(model_inputs.input_ids, generated_ids)]
        answer_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        app.logger.info(f"LLM reply: {answer_text}")
        return jsonify({"answer_text": answer_text})
    except Exception as e:
        app.logger.error(f"Text generation error: {str(e)}")
        return jsonify({"error": str(e)}), 500
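
# Illustrative request for /generate_text (not part of the service). The URL
# and the example utterance are assumptions for this sketch.
#
#   import requests
#   resp = requests.post("http://localhost:5000/generate_text",
#                        json={"asr_text": "你好，介绍一下你自己"})
#   print(resp.json())  # e.g. {"answer_text": "..."}
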
@app.route('/generate_audio', methods=['POST'])
def handle_generate_audio():
    """Handle text-to-speech (TTS) requests."""
    try:
        data = request.get_json()
        answer_text = data.get('answer_text', '')
        app.logger.info(f"Received text for synthesis: {answer_text}")
        if not answer_text:
            return jsonify({"error": "No answer text provided"}), 400
        # Identify the language and pick a matching edge-tts voice (default: Chinese)
        lang, _ = langid.classify(answer_text)
        speaker = language_speaker.get(lang, "zh-CN-XiaoyiNeural")
        app.logger.info(f"Detected language: {lang}, using voice: {speaker}")
        # Synthesize speech with edge-tts and save it as an MP3 file
        output_file = os.path.join(OUTPUT_DIR, f"response_{int(time.time())}.mp3")
        asyncio.run(edge_tts.Communicate(answer_text, speaker).save(output_file))
        app.logger.info(f"Speech synthesis finished, saved to: {output_file}")
        return jsonify({
            "audio_url": f"/audio/{os.path.basename(output_file)}"
        })
    except Exception as e:
        app.logger.error(f"TTS error: {str(e)}")
        return jsonify({"error": str(e)}), 500
@app.route('/audio/<filename>')
def get_audio(filename):
    """Serve synthesized audio files for download."""
    return send_from_directory(OUTPUT_DIR, filename)

if __name__ == '__main__':
    app.logger.info("Server starting on port 5000")
    app.run(port=5000, threaded=True)
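
# Illustrative end-to-end client flow (not part of the service): chain the
# three endpoints ASR -> LLM -> TTS and fetch the synthesized audio. The base
# URL and the input file name are assumptions for this sketch.
#
#   import requests
#   base = "http://localhost:5000"
#   with open("test.wav", "rb") as f:
#       asr = requests.post(f"{base}/asr", files={"audio": f}).json()
#   answer = requests.post(f"{base}/generate_text",
#                          json={"asr_text": asr["asr_text"]}).json()
#   audio = requests.post(f"{base}/generate_audio",
#                         json={"answer_text": answer["answer_text"]}).json()
#   with open("reply.mp3", "wb") as out:
#       out.write(requests.get(base + audio["audio_url"]).content)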