From 20c1eaeb6408993641fcb0529c25cf0716a4fe40 Mon Sep 17 00:00:00 2001 From: Joshua Meyer Date: Wed, 13 Dec 2023 01:37:29 +0100 Subject: [PATCH] Allow running on CPU --- server/Dockerfile.cpu | 21 +++++ server/main_cpu.py | 195 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 216 insertions(+) create mode 100644 server/Dockerfile.cpu create mode 100644 server/main_cpu.py diff --git a/server/Dockerfile.cpu b/server/Dockerfile.cpu new file mode 100644 index 0000000..497eb59 --- /dev/null +++ b/server/Dockerfile.cpu @@ -0,0 +1,21 @@ +FROM pytorch/pytorch:latest +ARG DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && \ + apt-get install --no-install-recommends -y sox libsox-fmt-all curl wget gcc git git-lfs build-essential libaio-dev libsndfile1 ssh ffmpeg && \ + apt-get clean && apt-get -y autoremove + +WORKDIR /app +COPY requirements.txt . +RUN python -m pip install --use-deprecated=legacy-resolver -r requirements.txt \ + && python -m pip cache purge + +RUN python -m unidic download +RUN mkdir -p /app/tts_models + +COPY main_cpu.py . +ENV NVIDIA_DISABLE_REQUIRE=0 + +ENV NUM_THREADS=2 +EXPOSE 80 +CMD ["uvicorn", "main_cpu:app", "--host", "0.0.0.0", "--port", "80"] diff --git a/server/main_cpu.py b/server/main_cpu.py new file mode 100644 index 0000000..544d751 --- /dev/null +++ b/server/main_cpu.py @@ -0,0 +1,195 @@ +import base64 +import io +import os +import tempfile +from typing import List, Literal +import wave + +import numpy as np +import torch +from fastapi import ( + FastAPI, + UploadFile, + Body, +) +from pydantic import BaseModel +from fastapi.responses import StreamingResponse + +from TTS.tts.configs.xtts_config import XttsConfig +from TTS.tts.models.xtts import Xtts +from TTS.utils.generic_utils import get_user_data_dir +from TTS.utils.manage import ModelManager + +torch.set_num_threads(int(os.environ.get("NUM_THREADS", "2"))) +device = torch.device("cpu") + +custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models") + +if os.path.exists(custom_model_path) and os.path.isfile(custom_model_path + "/config.json"): + model_path = custom_model_path + print("Loading custom model from", model_path, flush=True) +else: + print("Loading default model", flush=True) + model_name = "tts_models/multilingual/multi-dataset/xtts_v2" + print("Downloading XTTS Model:", model_name, flush=True) + ModelManager().download_model(model_name) + model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--")) + print("XTTS Model downloaded", flush=True) + +print("Loading XTTS", flush=True) +config = XttsConfig() +config.load_json(os.path.join(model_path, "config.json")) +model = Xtts.init_from_config(config) +model.load_checkpoint(config, checkpoint_dir=model_path, eval=True, use_deepspeed=False) +model.to(device) +print("XTTS Loaded.", flush=True) + +print("Running XTTS Server ...", flush=True) + +##### Run fastapi ##### +app = FastAPI( + title="XTTS Streaming server", + description="""XTTS Streaming server""", + version="0.0.1", + docs_url="/", +) + + +@app.post("/clone_speaker") +def predict_speaker(wav_file: UploadFile): + """Compute conditioning inputs from reference audio file.""" + temp_audio_name = next(tempfile._get_candidate_names()) + with open(temp_audio_name, "wb") as temp, torch.inference_mode(): + temp.write(io.BytesIO(wav_file.file.read()).getbuffer()) + gpt_cond_latent, speaker_embedding = model.get_conditioning_latents( + temp_audio_name + ) + return { + "gpt_cond_latent": gpt_cond_latent.cpu().squeeze().half().tolist(), + "speaker_embedding": speaker_embedding.cpu().squeeze().half().tolist(), + } + + +def postprocess(wav): + """Post process the output waveform""" + if isinstance(wav, list): + wav = torch.cat(wav, dim=0) + wav = wav.clone().detach().cpu().numpy() + wav = wav[None, : int(wav.shape[0])] + wav = np.clip(wav, -1, 1) + wav = (wav * 32767).astype(np.int16) + return wav + + +def encode_audio_common( + frame_input, encode_base64=True, sample_rate=24000, sample_width=2, channels=1 +): + """Return base64 encoded audio""" + wav_buf = io.BytesIO() + with wave.open(wav_buf, "wb") as vfout: + vfout.setnchannels(channels) + vfout.setsampwidth(sample_width) + vfout.setframerate(sample_rate) + vfout.writeframes(frame_input) + + wav_buf.seek(0) + if encode_base64: + b64_encoded = base64.b64encode(wav_buf.getbuffer()).decode("utf-8") + return b64_encoded + else: + return wav_buf.read() + + +class StreamingInputs(BaseModel): + speaker_embedding: List[float] + gpt_cond_latent: List[List[float]] + text: str + language: str + add_wav_header: bool = True + stream_chunk_size: str = "20" + + +def predict_streaming_generator(parsed_input: dict = Body(...)): + speaker_embedding = ( + torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1) + ) + gpt_cond_latent = ( + torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0) + ) + text = parsed_input.text + language = parsed_input.language + + stream_chunk_size = int(parsed_input.stream_chunk_size) + add_wav_header = parsed_input.add_wav_header + + + chunks = model.inference_stream( + text, + language, + gpt_cond_latent, + speaker_embedding, + stream_chunk_size=stream_chunk_size, + enable_text_splitting=True + ) + + for i, chunk in enumerate(chunks): + chunk = postprocess(chunk) + if i == 0 and add_wav_header: + yield encode_audio_common(b"", encode_base64=False) + yield chunk.tobytes() + else: + yield chunk.tobytes() + + +@app.post("/tts_stream") +def predict_streaming_endpoint(parsed_input: StreamingInputs): + return StreamingResponse( + predict_streaming_generator(parsed_input), + media_type="audio/wav", + ) + +class TTSInputs(BaseModel): + speaker_embedding: List[float] + gpt_cond_latent: List[List[float]] + text: str + language: str + +@app.post("/tts") +def predict_speech(parsed_input: TTSInputs): + speaker_embedding = ( + torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1) + ) + gpt_cond_latent = ( + torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0) + ) + text = parsed_input.text + language = parsed_input.language + + out = model.inference( + text, + language, + gpt_cond_latent, + speaker_embedding, + ) + + wav = postprocess(torch.tensor(out["wav"])) + + return encode_audio_common(wav.tobytes()) + + +@app.get("/studio_speakers") +def get_speakers(): + if hasattr(model, "speaker_manager") and hasattr(model.speaker_manager, "speakers"): + return { + speaker: { + "speaker_embedding": model.speaker_manager.speakers[speaker]["speaker_embedding"].cpu().squeeze().half().tolist(), + "gpt_cond_latent": model.speaker_manager.speakers[speaker]["gpt_cond_latent"].cpu().squeeze().half().tolist(), + } + for speaker in model.speaker_manager.speakers.keys() + } + else: + return {} + +@app.get("/languages") +def get_languages(): + return config.languages