""" Ray Serve deployment for Coqui TTS. Runs on: elminster (RTX 2070 8GB, CUDA) Provides three API styles: POST /tts — JSON body → JSON response with base64 audio POST /tts/stream — JSON body → SSE with per-sentence base64 audio chunks GET /tts/api/tts — Coqui-compatible query params → raw WAV bytes """ import base64 import io import json import os import re import time from typing import Any from fastapi import FastAPI, Query, Request from fastapi.responses import Response, StreamingResponse from ray import serve try: from ray_serve.mlflow_logger import InferenceLogger except ImportError: InferenceLogger = None _fastapi = FastAPI() # ── Sentence splitting for streaming ───────────────────────────────────── _SENTENCE_RE = re.compile(r"(?<=[.!?;])\s+|(?<=\n)\s*", re.MULTILINE) def _split_sentences(text: str, max_len: int = 200) -> list[str]: """Split text into sentences suitable for per-sentence TTS streaming.""" text = re.sub(r"[ \t]+", " ", text).strip() if not text: return [] raw_parts = _SENTENCE_RE.split(text) sentences: list[str] = [] for part in raw_parts: part = part.strip() if not part: continue if len(part) > max_len: for sp in re.split(r"(?<=[,;:])\s+", part): sp = sp.strip() if sp: sentences.append(sp) else: sentences.append(part) return sentences @serve.deployment(name="TTSDeployment", num_replicas=1) @serve.ingress(_fastapi) class TTSDeployment: def __init__(self): import torch from TTS.api import TTS self.model_name = os.environ.get("MODEL_NAME", "tts_models/en/ljspeech/tacotron2-DDC") # Detect device self.use_gpu = torch.cuda.is_available() print(f"Loading TTS model: {self.model_name}") print(f"Using GPU: {self.use_gpu}") self.tts = TTS(model_name=self.model_name, progress_bar=False) if self.use_gpu: self.tts = self.tts.to("cuda") print("TTS model loaded successfully") # MLflow metrics if InferenceLogger is not None: self._mlflow = InferenceLogger( experiment_name="ray-serve-tts", run_name=f"tts-{self.model_name.split('/')[-1]}", tags={"model.name": self.model_name, "model.framework": "coqui-tts", "gpu": str(self.use_gpu)}, flush_every=5, ) self._mlflow.initialize(params={"model_name": self.model_name, "use_gpu": str(self.use_gpu)}) else: self._mlflow = None # ── internal synthesis helpers ──────────────────────────────────────── def _synthesize(self, text: str, speaker: str | None = None, language: str | None = None, speed: float = 1.0): """Return (wav_bytes: bytes, sample_rate: int, duration: float).""" import numpy as np from scipy.io import wavfile wav = self.tts.tts(text=text, speaker=speaker, language=language, speed=speed) if not isinstance(wav, np.ndarray): wav = np.array(wav) wav_int16 = (wav * 32767).astype(np.int16) sample_rate = ( self.tts.synthesizer.output_sample_rate if hasattr(self.tts, "synthesizer") else 22050 ) buf = io.BytesIO() wavfile.write(buf, sample_rate, wav_int16) return buf.getvalue(), sample_rate, len(wav) / sample_rate def _log(self, start: float, duration: float, text_len: int): if self._mlflow: elapsed = time.time() - start self._mlflow.log_request( latency_s=elapsed, audio_duration_s=duration, text_chars=text_len, realtime_factor=elapsed / duration if duration > 0 else 0, ) # ── GET /health — simple liveness check ───────────────────────────── @_fastapi.get("/health") def health(self) -> dict[str, Any]: """Simple health/readiness check.""" return { "status": "ok", "model": self.model_name, "gpu": self.use_gpu, } # ── POST / — JSON API (base64 audio in response) ──────────────────── @_fastapi.post("/") async def generate_json(self, request: 

@serve.deployment(name="TTSDeployment", num_replicas=1)
@serve.ingress(_fastapi)
class TTSDeployment:
    def __init__(self):
        import torch
        from TTS.api import TTS

        self.model_name = os.environ.get(
            "MODEL_NAME", "tts_models/en/ljspeech/tacotron2-DDC"
        )

        # Detect device
        self.use_gpu = torch.cuda.is_available()
        print(f"Loading TTS model: {self.model_name}")
        print(f"Using GPU: {self.use_gpu}")

        self.tts = TTS(model_name=self.model_name, progress_bar=False)
        if self.use_gpu:
            self.tts = self.tts.to("cuda")
        print("TTS model loaded successfully")

        # MLflow metrics
        if InferenceLogger is not None:
            self._mlflow = InferenceLogger(
                experiment_name="ray-serve-tts",
                run_name=f"tts-{self.model_name.split('/')[-1]}",
                tags={
                    "model.name": self.model_name,
                    "model.framework": "coqui-tts",
                    "gpu": str(self.use_gpu),
                },
                flush_every=5,
            )
            self._mlflow.initialize(
                params={"model_name": self.model_name, "use_gpu": str(self.use_gpu)}
            )
        else:
            self._mlflow = None

    # ── internal synthesis helpers ────────────────────────────────────

    def _synthesize(
        self,
        text: str,
        speaker: str | None = None,
        language: str | None = None,
        speed: float = 1.0,
    ):
        """Return (wav_bytes: bytes, sample_rate: int, duration: float)."""
        import numpy as np
        from scipy.io import wavfile

        wav = self.tts.tts(text=text, speaker=speaker, language=language, speed=speed)
        if not isinstance(wav, np.ndarray):
            wav = np.array(wav)
        # Clip before the int16 conversion so samples slightly outside
        # [-1, 1] don't wrap around and produce clicks.
        wav_int16 = (np.clip(wav, -1.0, 1.0) * 32767).astype(np.int16)
        sample_rate = (
            self.tts.synthesizer.output_sample_rate
            if hasattr(self.tts, "synthesizer")
            else 22050
        )
        buf = io.BytesIO()
        wavfile.write(buf, sample_rate, wav_int16)
        return buf.getvalue(), sample_rate, len(wav) / sample_rate

    def _log(self, start: float, duration: float, text_len: int):
        if self._mlflow:
            elapsed = time.time() - start
            self._mlflow.log_request(
                latency_s=elapsed,
                audio_duration_s=duration,
                text_chars=text_len,
                realtime_factor=elapsed / duration if duration > 0 else 0,
            )

    # ── GET /health — simple liveness check ───────────────────────────

    @_fastapi.get("/health")
    def health(self) -> dict[str, Any]:
        """Simple health/readiness check."""
        return {
            "status": "ok",
            "model": self.model_name,
            "gpu": self.use_gpu,
        }

    # ── POST / — JSON API (base64 audio in response) ──────────────────

    @_fastapi.post("/")
    async def generate_json(self, request: dict[str, Any]) -> dict[str, Any]:
        """JSON API — POST body:

        {"text": "...", "speaker": "...", "language": "en", "speed": 1.0,
         "output_format": "wav", "return_base64": true}
        """
        _start = time.time()
        text = request.get("text", "")
        if not text:
            return {"error": "No text provided"}

        speaker = request.get("speaker")
        language = request.get("language")
        speed = request.get("speed", 1.0)
        output_format = request.get("output_format", "wav")
        return_base64 = request.get("return_base64", True)

        # Only pass language/speaker if the model supports it
        if not (hasattr(self.tts, "is_multi_lingual") and self.tts.is_multi_lingual):
            language = None
        if not (hasattr(self.tts, "is_multi_speaker") and self.tts.is_multi_speaker):
            speaker = None

        try:
            audio_bytes, sample_rate, duration = self._synthesize(
                text, speaker, language, speed
            )
            self._log(_start, duration, len(text))
            resp: dict[str, Any] = {
                "model": self.model_name,
                "sample_rate": sample_rate,
                "duration": duration,
                "format": output_format,
            }
            if return_base64:
                resp["audio"] = base64.b64encode(audio_bytes).decode("utf-8")
                return resp
            # Raw WAV bytes are not JSON-serializable, so fall back to a
            # binary response when base64 is not requested.
            return Response(content=audio_bytes, media_type="audio/wav")
        except Exception as e:
            return {"error": str(e), "model": self.model_name}

    # ── GET /api/tts — Coqui-compatible raw WAV endpoint ──────────────

    @_fastapi.get("/api/tts")
    async def generate_raw(
        self,
        text: str = Query(..., description="Text to synthesize"),
        language_id: str = Query("en", description="Language code"),
        speaker_id: str | None = Query(None, description="Speaker name"),
    ) -> Response:
        """Coqui XTTS-compatible endpoint — returns raw WAV bytes."""
        _start = time.time()
        if not text:
            return Response(content="text parameter required", status_code=400)

        # Only pass language/speaker if the model is multi-lingual/multi-speaker
        lang = (
            language_id
            if hasattr(self.tts, "is_multi_lingual") and self.tts.is_multi_lingual
            else None
        )
        spk = (
            speaker_id
            if hasattr(self.tts, "is_multi_speaker") and self.tts.is_multi_speaker
            else None
        )

        try:
            audio_bytes, _sr, duration = self._synthesize(text, spk, lang)
            self._log(_start, duration, len(text))
            return Response(content=audio_bytes, media_type="audio/wav")
        except Exception as e:
            return Response(content=str(e), status_code=500)
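
    # ── Example: consuming /stream (illustrative sketch) ──────────────
    # A minimal client sketch for the SSE endpoint below, assuming the
    # app is mounted under /tts on localhost:8000; `httpx` and the URL
    # are illustrative assumptions, not dependencies of this module.
    #
    #   with httpx.stream(
    #       "POST",
    #       "http://127.0.0.1:8000/tts/stream",
    #       json={"text": "One. Two. Three."},
    #       timeout=None,
    #   ) as r:
    #       for line in r.iter_lines():
    #           if not line.startswith("data: ") or line == "data: [DONE]":
    #               continue
    #           chunk = json.loads(line[len("data: "):])
    #           if not chunk.get("done"):
    #               wav_bytes = base64.b64decode(chunk["audio"])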

    # ── POST /stream — SSE per-sentence streaming ─────────────────────

    @_fastapi.post("/stream")
    async def generate_stream(self, request: Request) -> StreamingResponse:
        """Stream TTS audio as SSE events, one per sentence.

        Request body:
            {"text": "...", "language": "en", "speaker": null, "speed": 1.0}

        Each SSE event is a JSON object:
            data: {"text": "sentence", "audio": "<base64>", "index": 0,
                   "sample_rate": 24000, "duration": 1.23, "done": false}

        Final event:
            data: {"text": "", "audio": "", "index": N, "done": true}
        followed by:
            data: [DONE]
        """
        body = await request.json()
        text = body.get("text", "")
        if not text:

            async def _empty():
                yield 'data: {"error": "No text provided"}\n\n'
                yield "data: [DONE]\n\n"

            return StreamingResponse(_empty(), media_type="text/event-stream")

        speaker = body.get("speaker")
        language = body.get("language")
        speed = body.get("speed", 1.0)

        # Only pass language/speaker if the model supports it
        if not (hasattr(self.tts, "is_multi_lingual") and self.tts.is_multi_lingual):
            language = None
        if not (hasattr(self.tts, "is_multi_speaker") and self.tts.is_multi_speaker):
            speaker = None

        sentences = _split_sentences(text)

        async def _generate():
            # Synthesis itself is synchronous, so each sentence blocks the
            # event loop while it renders; events flush as they complete.
            for idx, sentence in enumerate(sentences):
                _start = time.time()
                try:
                    audio_bytes, sample_rate, duration = self._synthesize(
                        sentence, speaker, language, speed
                    )
                    self._log(_start, duration, len(sentence))
                    chunk = {
                        "text": sentence,
                        "audio": base64.b64encode(audio_bytes).decode("utf-8"),
                        "index": idx,
                        "sample_rate": sample_rate,
                        "duration": duration,
                        "done": False,
                    }
                except Exception as e:
                    chunk = {
                        "text": sentence,
                        "audio": "",
                        "index": idx,
                        "error": str(e),
                        "done": False,
                    }
                yield f"data: {json.dumps(chunk)}\n\n"
            yield f'data: {json.dumps({"text": "", "audio": "", "index": len(sentences), "done": True})}\n\n'
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            _generate(),
            media_type="text/event-stream",
            headers={"X-Accel-Buffering": "no", "Cache-Control": "no-cache"},
        )

    # ── GET /speakers — list available speakers ───────────────────────

    @_fastapi.get("/speakers")
    def list_speakers(self) -> dict[str, Any]:
        """List available speakers for multi-speaker models."""
        speakers = []
        if hasattr(self.tts, "speakers") and self.tts.speakers:
            speakers = self.tts.speakers
        return {
            "model": self.model_name,
            "speakers": speakers,
            "is_multi_speaker": len(speakers) > 0,
        }


app = TTSDeployment.bind()
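
# ── Running the deployment (illustrative sketch) ─────────────────────────
# A minimal sketch of serving this app under the /tts prefix assumed in
# the module docstring; the module name, host, port, and client call are
# illustrative assumptions, not part of this file.
#
#   from ray import serve
#   from tts_deployment import app   # hypothetical module name
#
#   serve.run(app, route_prefix="/tts")
#
#   import base64, requests
#   r = requests.post(
#       "http://127.0.0.1:8000/tts/",
#       json={"text": "Hello world", "return_base64": True},
#   )
#   wav_bytes = base64.b64decode(r.json()["audio"])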