From 194615ce043a1db8af3c2ffb2b69b17e1416142d Mon Sep 17 00:00:00 2001 From: linyq Date: Fri, 23 Aug 2024 13:53:13 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=20fastapi=20=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- TTS/server/server_fastapi.py | 236 +++++++++++++++++++++++++++++++++++ 1 file changed, 236 insertions(+) create mode 100644 TTS/server/server_fastapi.py diff --git a/TTS/server/server_fastapi.py b/TTS/server/server_fastapi.py new file mode 100644 index 0000000000..5d222bbf0e --- /dev/null +++ b/TTS/server/server_fastapi.py @@ -0,0 +1,236 @@ +""" +这是coqui-tts的fastapi接口 +pip install fastapi uvicorn python-multipart +""" +import io +import json +import os +import logging +import shutil +import uuid +import traceback +from pathlib import Path +from threading import Lock +from typing import Union, List, Optional +from pydantic import Field, BaseModel +from fastapi import FastAPI, Query, Form, HTTPException, UploadFile, File +from fastapi.responses import StreamingResponse, FileResponse + +from TTS.utils.manage import ModelManager +from TTS.utils.synthesizer import Synthesizer + +app = FastAPI() +lock = Lock() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +path = Path(__file__).parent / "../.models.json" +manager = ModelManager(path) + +# 将正在使用的模型更新为指定的已发布模型。 +model_path = None +config_path = None +speakers_file_path = None +vocoder_path = None +vocoder_config_path = None + + +def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict, None]: + """ + 该函数功能是将一个uri样式的wave文件路径或GST(Guided Style Tokens) + 字典转换为相应的格式。如果输入是一个字符串类型的路径,且该路径指向一个以".wav"结尾的文件, + 则直接返回该路径。如果输入是一个字符串类型的JSON表示的GST字典,则将其解析为字典格式并返回。 + 如果输入为空字符串或既不是文件路径也不是GST字典的字符串,则返回None。 + Args: + style_wav (str): uri + Returns: + Union[str, dict]: path to file (str) or gst style (dict) + """ + if style_wav: + if os.path.isfile(style_wav) and style_wav.endswith(".wav"): + return style_wav # style_wav 是位于服务器上的.wav文件 + style_wav = json.loads(style_wav) + return style_wav # style_wav 是带有 {token1_id : token1_weigth, ...} 的 GST 字典 + return None + + +@app.get("/list_models", summary="获取模型列表") +async def get_list_models(): + """ + 获取所有可用模型列表 + """ + tts_models = manager.list_tts_models() + vocoder_models = manager.list_vocoder_models() + return {"tts_models": tts_models, "vocoder_models": vocoder_models} + + +@app.get("/load_model", summary="加载模型") +async def load_models( + model_name: str = Query(description="模型名称"), + vocoder_name: str = Query(default=None, description="vocoder名称,一般不用填"), + config_path: str = Query(default=None, description="配置文件路径,一般不用填,只有`tts_models/multilingual/multi-dataset/xtts_v2`模型需要填 `/root/.local/share/tts/tts_models--multilingual--multi-dataset--xtts_v2/config.json`") +): + """ + 加载模型 + """ + global synthesizer + + model_path, tts_config_path, model_item = manager.download_model(model_name) + vocoder_name = model_item["default_vocoder"] if vocoder_name is None else vocoder_name + + if vocoder_name is not None: + vocoder_path, vocoder_config_path, _ = manager.download_model(vocoder_name) + + if config_path is not None: + tts_config_path = config_path + + # 加载模型 + synthesizer = Synthesizer( + tts_checkpoint=model_path, + tts_config_path=tts_config_path, + tts_speakers_file="", + tts_languages_file="", + vocoder_checkpoint="", + vocoder_config="", + encoder_checkpoint="", + encoder_config="", + use_cuda=os.getenv("CUDA", False) + ) + # 是否使用多发言人模式 + speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None) + # 是否使用多语言模式 + language_manager = getattr(synthesizer.tts_model, "language_manager", None) + + return { + "message": "模型加载成功", + "speaker_manager": speaker_manager, + "language_manager": language_manager + } + + +class TTSRequest(BaseModel): + """ + 语音合成请求 + """ + text: Optional[str] = Field(description="需要转换的文本") + speaker_idx: Optional[str] = Field(default=None, description="说话人 id") + language_idx: Optional[str] = Field(default=None, description="语言 id") + speed: Optional[float] = Field(default=1.0, description="生成音频的速度。默认为 1.0。(如果远低于 1.0,可能会产生伪影)"), + split_sentences: Optional[bool] = Field(default=True, description="将输入文本拆分为句子") + # Todo: 下面参数还没看懂实质上的作用;暂时不添加到接口中 + # style_wav: Optional[str] = Field(default=None, description="GST的样式波形") + # style_text: Optional[str] = Field(default=None, description="Capacitron 的 style_wav 转录") + # reference_wav: Optional[str] = Field(default=None, description="用于语音转换的参考波形") + # reference_speaker_name: Optional[str] = Field(default=None, description="参考波形的扬声器 ID") + +@app.get("/tts", summary="语音合成") +def tts( + text: Optional[str] = Query(description="需要转换的文本"), + speaker_idx: Optional[str] = Query(default=None, description="说话人"), + language_idx: Optional[str] = Query(default=None, description="语种;如:en"), + speed: Optional[float] = Query(default=1.0, description="生成音频的速度。默认为 1.0。(如果远低于 1.0,可能会产生伪影)"), + split_sentences: Optional[bool] = Query(default=True, description="将输入文本拆分为句子") +): + """ + 语音合成 + """ + try: + with lock: + # 使用异常处理来增加健壮性 + try: + wavs = synthesizer.tts( + text=text, + speaker_name=speaker_idx, + language_name=language_idx, + split_sentences=split_sentences, + speed=speed + ) + out = io.BytesIO() + synthesizer.save_wav(wavs, out) + # synthesizer.save_wav(wavs, "output_tts.wav") + except AttributeError as e: + logger.error(f"语音合成失败: {e}") + raise HTTPException(status_code=501, detail="未加载TTS模型") + except Exception: + logger.error(f"语音合成失败: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail="语音合成失败") + except Exception as e: + logger.error(f"处理请求时出错: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=str(e)) + + headers = { + 'Content-Disposition': 'inline; filename="output.wav"' + } + return StreamingResponse(out, media_type="audio/wav", headers=headers) + # return FileResponse("output_tts.wav", media_type="audio/wav", filename="output.wav") + + +@app.post("/clone_tts", summary="语音克隆并合成") +def clone_tts( + text: Optional[str] = Form(description="需要转换的文本"), + speaker_idx: Optional[str] = Form(default=None, description="说话人"), + language_idx: Optional[str] = Form(default=None, description="语种;如:en"), + speed: Optional[float] = Form(default=1.0, description="生成音频的速度。默认为 1.0。(如果远低于 1.0,可能会产生伪影)"), + split_sentences: Optional[bool] = Form(default=True, description="将输入文本拆分为句子"), + speaker_wav: List[UploadFile] = File(..., description="需要克隆的说话人音频wav文件;支持单个或多个文件;单个音频需要大于3s;wav文件数量和种类决定了克隆的效果") +): + """ + 语音克隆并合成 + """ + try: + with lock: + # 创建缓存目录 + cache_dir = "wav_cache" + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # 保存上传的文件并获取文件路径 + speaker_wav_paths = [] + if speaker_wav is not None: + for file in speaker_wav: + file_id = str(uuid.uuid4()) + file_path = os.path.join(cache_dir, f"{file_id}_{file.filename}") + with open(file_path, "wb") as f: + shutil.copyfileobj(file.file, f) + speaker_wav_paths.append(file_path) + + logger.info(f" > Model input: {text}") + logger.info(f" > Language Idx: {language_idx}") + + # 使用异常处理来增加健壮性 + try: + wavs = synthesizer.tts( + text=text, + language_name=language_idx, + speaker_wav=speaker_wav_paths, + split_sentences=split_sentences, + speed=speed, + speaker_name=speaker_idx + ) + out = io.BytesIO() + synthesizer.save_wav(wavs, out) + except AttributeError as e: + logger.error(f"语音合成失败: {e}") + raise HTTPException(status_code=501, detail="未加载TTS模型") + except Exception: + logger.error(f"语音合成失败: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail="语音合成失败") + except Exception as e: + logger.error(f"处理请求时出错: {e}") + raise HTTPException(status_code=500, detail=str(e)) + finally: + # 清理缓存文件 + for file_path in speaker_wav_paths: + if os.path.exists(file_path): + os.remove(file_path) + + # 如果缓存目录为空,则删除 + if not os.listdir(cache_dir): + os.rmdir(cache_dir) + return StreamingResponse(out, media_type="audio/wav") + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host='0.0.0.0', port=5002, log_level="debug")