# ray-serve/ray_serve/serve_tts.py
"""
Ray Serve deployment for Coqui TTS.
Runs on: elminster (RTX 2070 8GB, CUDA)
Provides three API styles:
POST /tts — JSON body → JSON response with base64 audio
POST /tts/stream — JSON body → SSE with per-sentence base64 audio chunks
GET /tts/api/tts — Coqui-compatible query params → raw WAV bytes
"""
import base64
import io
import json
import os
import re
import time
from typing import Any
from fastapi import FastAPI, Query, Request
from fastapi.responses import Response, StreamingResponse
from ray import serve
try:
    from ray_serve.mlflow_logger import InferenceLogger
except ImportError:
    InferenceLogger = None
_fastapi = FastAPI()
# ── Sentence splitting for streaming ─────────────────────────────────────
_SENTENCE_RE = re.compile(r"(?<=[.!?;])\s+|(?<=\n)\s*", re.MULTILINE)
def _split_sentences(text: str, max_len: int = 200) -> list[str]:
    """Split text into sentences suitable for per-sentence TTS streaming."""
    text = re.sub(r"[ \t]+", " ", text).strip()
    if not text:
        return []
    raw_parts = _SENTENCE_RE.split(text)
    sentences: list[str] = []
    for part in raw_parts:
        part = part.strip()
        if not part:
            continue
        if len(part) > max_len:
            # Overlong sentence: split again at clause punctuation.
            for sp in re.split(r"(?<=[,;:])\s+", part):
                sp = sp.strip()
                if sp:
                    sentences.append(sp)
        else:
            sentences.append(part)
    return sentences
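# Example (illustrative) of the splitting behavior; parts longer than
# max_len are broken again at commas, semicolons, and colons:
#
#     _split_sentences("Hello there. How are you?\nFine!")
#     -> ["Hello there.", "How are you?", "Fine!"]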
@serve.deployment(name="TTSDeployment", num_replicas=1)
@serve.ingress(_fastapi)
class TTSDeployment:
    def __init__(self):
        import torch
        from TTS.api import TTS

        self.model_name = os.environ.get(
            "MODEL_NAME", "tts_models/en/ljspeech/tacotron2-DDC"
        )
        # Detect device
        self.use_gpu = torch.cuda.is_available()
        print(f"Loading TTS model: {self.model_name}")
        print(f"Using GPU: {self.use_gpu}")
        self.tts = TTS(model_name=self.model_name, progress_bar=False)
        if self.use_gpu:
            self.tts = self.tts.to("cuda")
        print("TTS model loaded successfully")

        # MLflow metrics
        if InferenceLogger is not None:
            self._mlflow = InferenceLogger(
                experiment_name="ray-serve-tts",
                run_name=f"tts-{self.model_name.split('/')[-1]}",
                tags={
                    "model.name": self.model_name,
                    "model.framework": "coqui-tts",
                    "gpu": str(self.use_gpu),
                },
                flush_every=5,
            )
            self._mlflow.initialize(
                params={"model_name": self.model_name, "use_gpu": str(self.use_gpu)}
            )
        else:
            self._mlflow = None
    # ── internal synthesis helpers ────────────────────────────────────────
    def _synthesize(self, text: str, speaker: str | None = None,
                    language: str | None = None, speed: float = 1.0):
        """Return (wav_bytes: bytes, sample_rate: int, duration: float)."""
        import numpy as np
        from scipy.io import wavfile

        wav = self.tts.tts(text=text, speaker=speaker, language=language, speed=speed)
        if not isinstance(wav, np.ndarray):
            wav = np.array(wav)
        # Clip to [-1, 1] before scaling so float peaks cannot overflow int16.
        wav_int16 = (np.clip(wav, -1.0, 1.0) * 32767).astype(np.int16)
        sample_rate = (
            self.tts.synthesizer.output_sample_rate
            if hasattr(self.tts, "synthesizer")
            else 22050
        )
        buf = io.BytesIO()
        wavfile.write(buf, sample_rate, wav_int16)
        return buf.getvalue(), sample_rate, len(wav) / sample_rate
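    # Example (illustrative): the bytes returned above are a complete WAV
    # file, so a caller can round-trip them with scipy; names are local to
    # this sketch:
    #
    #     import io
    #     from scipy.io import wavfile
    #     rate, samples = wavfile.read(io.BytesIO(wav_bytes))
    #     # rate == sample_rate, samples.dtype == int16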
    def _log(self, start: float, duration: float, text_len: int):
        if self._mlflow:
            elapsed = time.time() - start
            self._mlflow.log_request(
                latency_s=elapsed,
                audio_duration_s=duration,
                text_chars=text_len,
                realtime_factor=elapsed / duration if duration > 0 else 0,
            )
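    # Reading realtime_factor: synthesis wall time divided by the duration of
    # the audio produced. E.g. (numbers illustrative) 0.5 s spent producing
    # 2.0 s of audio gives 0.5 / 2.0 = 0.25, i.e. 4x faster than realtime;
    # values above 1.0 mean synthesis is slower than playback.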
    # ── GET /health — simple liveness check ─────────────────────────────
    @_fastapi.get("/health")
    def health(self) -> dict[str, Any]:
        """Simple health/readiness check."""
        return {
            "status": "ok",
            "model": self.model_name,
            "gpu": self.use_gpu,
        }
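    # Example check (illustrative host/port; the /tts route prefix follows
    # the module docstring):
    #
    #     import requests
    #     r = requests.get("http://127.0.0.1:8000/tts/health", timeout=5)
    #     r.json()  # {'status': 'ok', 'model': 'tts_models/...', 'gpu': True}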
    # ── POST / — JSON API (base64 audio in response) ────────────────────
    @_fastapi.post("/")
    async def generate_json(self, request: dict[str, Any]) -> dict[str, Any]:
        """
        JSON API — POST body:
            {"text": "...", "speaker": "...", "language": "en", "speed": 1.0,
             "output_format": "wav", "return_base64": true}
        """
        _start = time.time()
        text = request.get("text", "")
        if not text:
            return {"error": "No text provided"}
        speaker = request.get("speaker")
        language = request.get("language")
        speed = request.get("speed", 1.0)
        output_format = request.get("output_format", "wav")
        return_base64 = request.get("return_base64", True)
        # Only pass language/speaker if the model supports them
        if not (hasattr(self.tts, "is_multi_lingual") and self.tts.is_multi_lingual):
            language = None
        if not (hasattr(self.tts, "is_multi_speaker") and self.tts.is_multi_speaker):
            speaker = None
        try:
            audio_bytes, sample_rate, duration = self._synthesize(
                text, speaker, language, speed
            )
            self._log(_start, duration, len(text))
            resp: dict[str, Any] = {
                "model": self.model_name,
                "sample_rate": sample_rate,
                "duration": duration,
                "format": output_format,
            }
            if return_base64:
                resp["audio"] = base64.b64encode(audio_bytes).decode("utf-8")
            else:
                # NOTE: raw WAV bytes are not JSON-serializable (arbitrary
                # binary fails UTF-8 decoding), so callers should normally
                # keep return_base64 true.
                resp["audio_bytes"] = audio_bytes
            return resp
        except Exception as e:
            return {"error": str(e), "model": self.model_name}
    # ── GET /api/tts — Coqui-compatible raw WAV endpoint ─────────────────
    @_fastapi.get("/api/tts")
    async def generate_raw(
        self,
        text: str = Query(..., description="Text to synthesize"),
        language_id: str = Query("en", description="Language code"),
        speaker_id: str | None = Query(None, description="Speaker name"),
    ) -> Response:
        """Coqui XTTS-compatible endpoint — returns raw WAV bytes."""
        _start = time.time()
        if not text:
            return Response(content="text parameter required", status_code=400)
        # Only pass language/speaker if the model is multi-lingual/multi-speaker
        lang = (
            language_id
            if hasattr(self.tts, "is_multi_lingual") and self.tts.is_multi_lingual
            else None
        )
        spk = (
            speaker_id
            if hasattr(self.tts, "is_multi_speaker") and self.tts.is_multi_speaker
            else None
        )
        try:
            audio_bytes, _sr, duration = self._synthesize(text, spk, lang)
            self._log(_start, duration, len(text))
            return Response(content=audio_bytes, media_type="audio/wav")
        except Exception as e:
            return Response(content=str(e), status_code=500)
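    # Example call (illustrative host/port): requests URL-encodes the query
    # string, so spaces in `text` are handled automatically:
    #
    #     import requests
    #     r = requests.get(
    #         "http://127.0.0.1:8000/tts/api/tts",
    #         params={"text": "Hello world"},
    #         timeout=60,
    #     )
    #     open("out.wav", "wb").write(r.content)  # raw WAV bytes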
    # ── POST /stream — SSE per-sentence streaming ──────────────────────
    @_fastapi.post("/stream")
    async def generate_stream(self, request: Request) -> StreamingResponse:
        """Stream TTS audio as SSE events, one per sentence.

        Request body:
            {"text": "...", "language": "en", "speaker": null, "speed": 1.0}

        Each SSE event is a JSON object:
            data: {"text": "sentence", "audio": "<base64 WAV>", "index": 0,
                   "sample_rate": 24000, "duration": 1.23, "done": false}

        Final event:
            data: {"text": "", "audio": "", "index": N, "done": true}
        followed by:
            data: [DONE]
        """
        body = await request.json()
        text = body.get("text", "")
        if not text:
            async def _empty():
                yield 'data: {"error": "No text provided"}\n\n'
                yield "data: [DONE]\n\n"
            return StreamingResponse(_empty(), media_type="text/event-stream")
        speaker = body.get("speaker")
        language = body.get("language")
        speed = body.get("speed", 1.0)
        # Only pass language/speaker if the model supports them
        if not (hasattr(self.tts, "is_multi_lingual") and self.tts.is_multi_lingual):
            language = None
        if not (hasattr(self.tts, "is_multi_speaker") and self.tts.is_multi_speaker):
            speaker = None
        sentences = _split_sentences(text)

        async def _generate():
            for idx, sentence in enumerate(sentences):
                _start = time.time()
                try:
                    audio_bytes, sample_rate, duration = self._synthesize(
                        sentence, speaker, language, speed
                    )
                    self._log(_start, duration, len(sentence))
                    chunk = {
                        "text": sentence,
                        "audio": base64.b64encode(audio_bytes).decode("utf-8"),
                        "index": idx,
                        "sample_rate": sample_rate,
                        "duration": duration,
                        "done": False,
                    }
                except Exception as e:
                    chunk = {
                        "text": sentence,
                        "audio": "",
                        "index": idx,
                        "error": str(e),
                        "done": False,
                    }
                yield f"data: {json.dumps(chunk)}\n\n"
            yield f'data: {json.dumps({"text": "", "audio": "", "index": len(sentences), "done": True})}\n\n'
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            _generate(),
            media_type="text/event-stream",
            headers={"X-Accel-Buffering": "no", "Cache-Control": "no-cache"},
        )
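    # Example consumer (illustrative host/port), parsing the SSE framing by
    # hand rather than with an SSE client library:
    #
    #     import base64, json, requests
    #     with requests.post(
    #         "http://127.0.0.1:8000/tts/stream",
    #         json={"text": "First sentence. Second sentence."},
    #         stream=True,
    #         timeout=300,
    #     ) as r:
    #         for line in r.iter_lines(decode_unicode=True):
    #             if not line or not line.startswith("data: "):
    #                 continue
    #             payload = line[len("data: "):]
    #             if payload == "[DONE]":
    #                 break
    #             chunk = json.loads(payload)
    #             if chunk.get("audio"):
    #                 wav = base64.b64decode(chunk["audio"])  # one WAV per sentence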
    # ── GET /speakers — list available speakers ──────────────────────────
    @_fastapi.get("/speakers")
    def list_speakers(self) -> dict[str, Any]:
        """List available speakers for multi-speaker models."""
        speakers = []
        if hasattr(self.tts, "speakers") and self.tts.speakers:
            speakers = self.tts.speakers
        return {
            "model": self.model_name,
            "speakers": speakers,
            "is_multi_speaker": len(speakers) > 0,
        }
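    # Example response for a single-speaker model such as the default
    # ljspeech/tacotron2-DDC (shape follows the code above):
    #
    #     {"model": "tts_models/en/ljspeech/tacotron2-DDC",
    #      "speakers": [], "is_multi_speaker": false}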
app = TTSDeployment.bind()
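# Entry point for Ray Serve. A minimal local launch might look like the
# following (illustrative; deploy options and the /tts route prefix are
# configured outside this file):
#
#     serve run ray_serve.serve_tts:app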