ray-serve/ray_serve/serve_tts.py
Billy D. 194a431e8c
feat(tts): add streaming SSE endpoint and sentence splitter
- Add POST /stream SSE endpoint that splits text into sentences,
  synthesizes each individually, and streams base64 WAV via SSE events
- Add _split_sentences() helper for robust sentence boundary detection
- Enables progressive audio playback for lower time-to-first-audio
2026-02-22 10:45:58 -05:00

"""
Ray Serve deployment for Coqui TTS.
Runs on: elminster (RTX 2070 8GB, CUDA)
Provides three API styles:
POST /tts — JSON body → JSON response with base64 audio
POST /tts/stream — JSON body → SSE with per-sentence base64 audio chunks
GET /tts/api/tts — Coqui-compatible query params → raw WAV bytes
"""
import base64
import io
import json
import os
import re
import time
from typing import Any

from fastapi import FastAPI, Query, Request
from fastapi.responses import Response, StreamingResponse
from ray import serve

try:
    from ray_serve.mlflow_logger import InferenceLogger
except ImportError:
    InferenceLogger = None

_fastapi = FastAPI()

# ── Sentence splitting for streaming ─────────────────────────────────────
_SENTENCE_RE = re.compile(r"(?<=[.!?;])\s+|(?<=\n)\s*", re.MULTILINE)


def _split_sentences(text: str, max_len: int = 200) -> list[str]:
    """Split text into sentences suitable for per-sentence TTS streaming."""
    text = re.sub(r"[ \t]+", " ", text).strip()
    if not text:
        return []
    raw_parts = _SENTENCE_RE.split(text)
    sentences: list[str] = []
    for part in raw_parts:
        part = part.strip()
        if not part:
            continue
        if len(part) > max_len:
            # Break overlong sentences again at clause boundaries so each
            # chunk stays short enough for low-latency synthesis.
            for sp in re.split(r"(?<=[,;:])\s+", part):
                sp = sp.strip()
                if sp:
                    sentences.append(sp)
        else:
            sentences.append(part)
    return sentences
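
# A quick illustration of the splitter's behavior (hypothetical input, shown
# as a comment rather than a doctest):
#
#   _split_sentences("Hello there. How are you?\nFine, thanks!")
#   -> ["Hello there.", "How are you?", "Fine, thanks!"]
#
# Sentences longer than max_len are split again at commas, semicolons, and
# colons, so no single TTS call receives an unbounded run of text.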


@serve.deployment(name="TTSDeployment", num_replicas=1)
@serve.ingress(_fastapi)
class TTSDeployment:
    def __init__(self):
        import torch
        from TTS.api import TTS

        self.model_name = os.environ.get(
            "MODEL_NAME", "tts_models/en/ljspeech/tacotron2-DDC"
        )
        # Detect device
        self.use_gpu = torch.cuda.is_available()
        print(f"Loading TTS model: {self.model_name}")
        print(f"Using GPU: {self.use_gpu}")
        self.tts = TTS(model_name=self.model_name, progress_bar=False)
        if self.use_gpu:
            self.tts = self.tts.to("cuda")
        print("TTS model loaded successfully")

        # MLflow metrics
        if InferenceLogger is not None:
            self._mlflow = InferenceLogger(
                experiment_name="ray-serve-tts",
                run_name=f"tts-{self.model_name.split('/')[-1]}",
                tags={
                    "model.name": self.model_name,
                    "model.framework": "coqui-tts",
                    "gpu": str(self.use_gpu),
                },
                flush_every=5,
            )
            self._mlflow.initialize(
                params={"model_name": self.model_name, "use_gpu": str(self.use_gpu)}
            )
        else:
            self._mlflow = None

    # ── internal synthesis helpers ────────────────────────────────────────
    def _synthesize(self, text: str, speaker: str | None = None,
                    language: str | None = None, speed: float = 1.0):
        """Return (wav_bytes: bytes, sample_rate: int, duration: float)."""
        import numpy as np
        from scipy.io import wavfile

        wav = self.tts.tts(text=text, speaker=speaker, language=language, speed=speed)
        if not isinstance(wav, np.ndarray):
            wav = np.array(wav)
        # Clip before converting to int16 so samples slightly outside [-1, 1]
        # don't overflow and wrap around.
        wav_int16 = (np.clip(wav, -1.0, 1.0) * 32767).astype(np.int16)
        sample_rate = (
            self.tts.synthesizer.output_sample_rate
            if hasattr(self.tts, "synthesizer")
            else 22050
        )
        buf = io.BytesIO()
        wavfile.write(buf, sample_rate, wav_int16)
        return buf.getvalue(), sample_rate, len(wav) / sample_rate

    def _log(self, start: float, duration: float, text_len: int):
        if self._mlflow:
            elapsed = time.time() - start
            self._mlflow.log_request(
                latency_s=elapsed,
                audio_duration_s=duration,
                text_chars=text_len,
                realtime_factor=elapsed / duration if duration > 0 else 0,
            )

    # ── POST / — JSON API (base64 audio in response) ─────────────────────
    @_fastapi.post("/")
    async def generate_json(self, request: dict[str, Any]) -> dict[str, Any]:
        """
        JSON API — POST body:
            {"text": "...", "speaker": "...", "language": "en", "speed": 1.0,
             "output_format": "wav", "return_base64": true}
        """
        _start = time.time()
        text = request.get("text", "")
        if not text:
            return {"error": "No text provided"}
        speaker = request.get("speaker")
        language = request.get("language")
        speed = request.get("speed", 1.0)
        output_format = request.get("output_format", "wav")
        return_base64 = request.get("return_base64", True)
        try:
            audio_bytes, sample_rate, duration = self._synthesize(
                text, speaker, language, speed
            )
            self._log(_start, duration, len(text))
            resp: dict[str, Any] = {
                "model": self.model_name,
                "sample_rate": sample_rate,
                "duration": duration,
                "format": output_format,
            }
            if return_base64:
                resp["audio"] = base64.b64encode(audio_bytes).decode("utf-8")
            else:
                # Raw bytes are not JSON-serializable; this branch is only
                # useful for in-process callers (e.g. a Ray handle), not for
                # plain HTTP clients.
                resp["audio_bytes"] = audio_bytes
            return resp
        except Exception as e:
            return {"error": str(e), "model": self.model_name}
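
    # Example call to the JSON endpoint (hypothetical host and port; the /tts
    # route prefix is assumed to come from the Serve app config):
    #
    #   curl -s -X POST http://localhost:8000/tts \
    #        -H 'Content-Type: application/json' \
    #        -d '{"text": "Hello world."}' | jq -r .audio | base64 -d > out.wav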

    # ── GET /api/tts — Coqui-compatible raw WAV endpoint ─────────────────
    @_fastapi.get("/api/tts")
    async def generate_raw(
        self,
        text: str = Query(..., description="Text to synthesize"),
        language_id: str = Query("en", description="Language code"),
        speaker_id: str | None = Query(None, description="Speaker name"),
    ) -> Response:
        """Coqui XTTS-compatible endpoint — returns raw WAV bytes."""
        _start = time.time()
        if not text:
            return Response(content="text parameter required", status_code=400)
        try:
            audio_bytes, _sr, duration = self._synthesize(
                text, speaker_id, language_id
            )
            self._log(_start, duration, len(text))
            return Response(content=audio_bytes, media_type="audio/wav")
        except Exception as e:
            return Response(content=str(e), status_code=500)
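
    # Example (hypothetical host and port), mirroring the query-parameter
    # style of the upstream Coqui TTS server:
    #
    #   curl -s 'http://localhost:8000/tts/api/tts?text=Hello%20world' > out.wav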

    # ── POST /stream — SSE per-sentence streaming ────────────────────────
    @_fastapi.post("/stream")
    async def generate_stream(self, request: Request) -> StreamingResponse:
        """Stream TTS audio as SSE events, one per sentence.

        Request body:
            {"text": "...", "language": "en", "speaker": null, "speed": 1.0}

        Each SSE event is a JSON object:
            data: {"text": "sentence", "audio": "<base64 WAV>", "index": 0,
                   "sample_rate": 24000, "duration": 1.23, "done": false}

        Final event:
            data: {"text": "", "audio": "", "index": N, "done": true}
        followed by:
            data: [DONE]
        """
        body = await request.json()
        text = body.get("text", "")
        if not text:
            async def _empty():
                yield 'data: {"error": "No text provided"}\n\n'
                yield "data: [DONE]\n\n"
            return StreamingResponse(_empty(), media_type="text/event-stream")

        speaker = body.get("speaker")
        language = body.get("language")
        speed = body.get("speed", 1.0)
        sentences = _split_sentences(text)

        async def _generate():
            for idx, sentence in enumerate(sentences):
                _start = time.time()
                try:
                    # Synthesis is synchronous, so each chunk is fully
                    # rendered before it is flushed to the client.
                    audio_bytes, sample_rate, duration = self._synthesize(
                        sentence, speaker, language, speed
                    )
                    self._log(_start, duration, len(sentence))
                    chunk = {
                        "text": sentence,
                        "audio": base64.b64encode(audio_bytes).decode("utf-8"),
                        "index": idx,
                        "sample_rate": sample_rate,
                        "duration": duration,
                        "done": False,
                    }
                except Exception as e:
                    # Report the failure for this sentence but keep streaming
                    # the remaining ones.
                    chunk = {
                        "text": sentence,
                        "audio": "",
                        "index": idx,
                        "error": str(e),
                        "done": False,
                    }
                yield f"data: {json.dumps(chunk)}\n\n"
            yield f'data: {json.dumps({"text": "", "audio": "", "index": len(sentences), "done": True})}\n\n'
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            _generate(),
            media_type="text/event-stream",
            headers={"X-Accel-Buffering": "no", "Cache-Control": "no-cache"},
        )
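
    # Sketch of a client consuming the stream (hypothetical host and port;
    # assumes the `requests` library). Playback can begin as soon as the
    # first chunk arrives:
    #
    #   import base64, json, requests
    #   resp = requests.post("http://localhost:8000/tts/stream",
    #                        json={"text": "Hello there. How are you?"},
    #                        stream=True)
    #   for line in resp.iter_lines():
    #       if not line.startswith(b"data: "):
    #           continue
    #       payload = line[len(b"data: "):]
    #       if payload == b"[DONE]":
    #           break
    #       chunk = json.loads(payload)
    #       if chunk.get("audio"):
    #           with open(f"chunk_{chunk['index']}.wav", "wb") as f:
    #               f.write(base64.b64decode(chunk["audio"]))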

    # ── GET /speakers — list available speakers ──────────────────────────
    @_fastapi.get("/speakers")
    def list_speakers(self) -> dict[str, Any]:
        """List available speakers for multi-speaker models."""
        speakers = []
        if hasattr(self.tts, "speakers") and self.tts.speakers:
            speakers = self.tts.speakers
        return {
            "model": self.model_name,
            "speakers": speakers,
            "is_multi_speaker": len(speakers) > 0,
        }


app = TTSDeployment.bind()
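
# To deploy locally, a sketch assuming this module is importable as
# `ray_serve.serve_tts` (the /tts route prefix would be configured at deploy
# time, e.g. in the Serve app config):
#
#   serve run ray_serve.serve_tts:app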