"""
Ray Serve deployment for Coqui TTS.
Runs on: elminster (RTX 2070 8GB, CUDA)

Provides three API styles:
  POST /tts          — JSON body → JSON response with base64 audio
  POST /tts/stream   — JSON body → SSE with per-sentence base64 audio chunks
  GET  /tts/api/tts  — Coqui-compatible query params → raw WAV bytes

The handlers in this module are registered on "/", "/stream", "/api/tts" and
"/speakers"; the "/tts" prefix shown above is presumably added by the Serve
application's route prefix, which is configured outside this file.
"""

import base64
import io
import json
import os
import re
import time
from typing import Any

from fastapi import FastAPI, Query, Request
from fastapi.responses import Response, StreamingResponse
from ray import serve

try:
    from ray_serve.mlflow_logger import InferenceLogger
except ImportError:
    # Metrics logging is optional; the deployment runs without it.
    InferenceLogger = None

_fastapi = FastAPI()

# ── Sentence splitting for streaming ─────────────────────────────────────

# Split after sentence-ending punctuation followed by whitespace, or right
# after a newline.
_SENTENCE_RE = re.compile(r"(?<=[.!?;])\s+|(?<=\n)\s*", re.MULTILINE)


def _split_sentences(text: str, max_len: int = 200) -> list[str]:
    """Split text into sentences suitable for per-sentence TTS streaming.

    Runs of spaces/tabs are collapsed first. The text is then split at
    sentence boundaries; any sentence longer than ``max_len`` characters is
    split again after ``,``, ``;`` or ``:``. A clause without such
    punctuation is kept whole, so a returned piece can still exceed
    ``max_len``.

    Args:
        text: Raw input text.
        max_len: Sentence length above which clause-level splitting is tried.

    Returns:
        Non-empty, stripped text pieces in their original order; ``[]`` for
        empty or whitespace-only input.
    """
    text = re.sub(r"[ \t]+", " ", text).strip()
    if not text:
        return []

    sentences: list[str] = []
    for part in _SENTENCE_RE.split(text):
        part = part.strip()
        if not part:
            continue
        if len(part) <= max_len:
            sentences.append(part)
            continue
        # Over-long sentence: fall back to clause boundaries.
        clauses = (sp.strip() for sp in re.split(r"(?<=[,;:])\s+", part))
        sentences.extend(sp for sp in clauses if sp)
    return sentences


def _sse_error_response(message: str) -> StreamingResponse:
    """Return an SSE stream with one error event followed by ``[DONE]``."""

    async def _events():
        yield f"data: {json.dumps({'error': message})}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(_events(), media_type="text/event-stream")


@serve.deployment(name="TTSDeployment", num_replicas=1)
@serve.ingress(_fastapi)
class TTSDeployment:
    """Single-replica Ray Serve deployment wrapping one Coqui TTS model.

    The model name is read from the ``MODEL_NAME`` environment variable.

    NOTE(review): synthesis is a blocking call made directly inside the
    ``async`` handlers, so it occupies the replica's event loop while it
    runs. That also serializes requests on the model. Moving it to a worker
    thread would change that serialization and needs a deliberate decision.
    """

    def __init__(self):
        """Load the TTS model (on CUDA when available) and set up metrics."""
        # Heavy imports are deferred so they only happen inside the replica.
        import torch
        from TTS.api import TTS

        self.model_name = os.environ.get(
            "MODEL_NAME", "tts_models/en/ljspeech/tacotron2-DDC"
        )

        # Detect device
        self.use_gpu = torch.cuda.is_available()

        print(f"Loading TTS model: {self.model_name}")
        print(f"Using GPU: {self.use_gpu}")

        self.tts = TTS(model_name=self.model_name, progress_bar=False)
        if self.use_gpu:
            self.tts = self.tts.to("cuda")

        print("TTS model loaded successfully")

        # MLflow metrics
        if InferenceLogger is not None:
            self._mlflow = InferenceLogger(
                experiment_name="ray-serve-tts",
                run_name=f"tts-{self.model_name.split('/')[-1]}",
                tags={
                    "model.name": self.model_name,
                    "model.framework": "coqui-tts",
                    "gpu": str(self.use_gpu),
                },
                flush_every=5,
            )
            self._mlflow.initialize(
                params={"model_name": self.model_name, "use_gpu": str(self.use_gpu)}
            )
        else:
            self._mlflow = None

    # ── internal synthesis helpers ────────────────────────────────────────

    def _synthesize(
        self,
        text: str,
        speaker: str | None = None,
        language: str | None = None,
        speed: float = 1.0,
    ):
        """Synthesize ``text`` and encode it as 16-bit PCM WAV.

        Args:
            text: Text to speak.
            speaker: Speaker name passed to the model, or ``None``.
            language: Language code passed to the model, or ``None``.
            speed: Speed factor forwarded to the model.

        Returns:
            ``(wav_bytes, sample_rate, duration)``: the WAV file contents,
            its sample rate, and its length in seconds.
        """
        import numpy as np
        from scipy.io import wavfile

        # The Coqui-compatible endpoint always supplies a language ("en" by
        # default). Coqui is understood to reject a language on a model that
        # is not multilingual, so drop it when the model explicitly reports
        # that. If the attribute is missing, the value is passed through.
        if language is not None and not getattr(self.tts, "is_multi_lingual", True):
            language = None

        wav = np.asarray(
            self.tts.tts(text=text, speaker=speaker, language=language, speed=speed)
        )

        # Clip before scaling: samples outside [-1, 1] would otherwise wrap
        # around in int16 and produce loud clicks.
        wav_int16 = (np.clip(wav, -1.0, 1.0) * 32767).astype(np.int16)

        # Use the 22050 default when the synthesizer or its rate is missing.
        synthesizer = getattr(self.tts, "synthesizer", None)
        sample_rate = getattr(synthesizer, "output_sample_rate", None) or 22050

        buf = io.BytesIO()
        wavfile.write(buf, sample_rate, wav_int16)
        return buf.getvalue(), sample_rate, len(wav) / sample_rate

    def _log(self, start: float, duration: float, text_len: int):
        """Record request metrics; does nothing when MLflow logging is off.

        Args:
            start: ``time.time()`` value taken when the request began.
            duration: Length of the generated audio in seconds.
            text_len: Number of characters synthesized.
        """
        if not self._mlflow:
            return
        elapsed = time.time() - start
        self._mlflow.log_request(
            latency_s=elapsed,
            audio_duration_s=duration,
            text_chars=text_len,
            # Processing time per second of audio; 0 when no audio was made.
            realtime_factor=elapsed / duration if duration > 0 else 0,
        )

    # ── POST / — JSON API (base64 audio in response) ────────────────────

    # response_model=None: the handler returns either a JSON dict or a raw
    # Response, so the return annotation must not be used as a response model.
    @_fastapi.post("/", response_model=None)
    async def generate_json(self, request: dict[str, Any]) -> dict[str, Any] | Response:
        """
        JSON API — POST body:
        {"text": "...", "speaker": "...", "language": "en", "speed": 1.0,
         "output_format": "wav", "return_base64": true}

        Returns a JSON object with ``model``, ``sample_rate``, ``duration``,
        ``format`` and base64 ``audio``. When ``return_base64`` is false the
        WAV bytes are returned directly as an ``audio/wav`` response, since
        raw bytes cannot be carried in a JSON body. ``output_format`` is only
        echoed back; the audio is always WAV.

        Failures are reported as a JSON object with an ``error`` key.
        """
        _start = time.time()
        text = request.get("text", "")
        if not text:
            return {"error": "No text provided"}

        speaker = request.get("speaker")
        language = request.get("language")
        speed = request.get("speed", 1.0)
        output_format = request.get("output_format", "wav")
        return_base64 = request.get("return_base64", True)

        try:
            audio_bytes, sample_rate, duration = self._synthesize(
                text, speaker, language, speed
            )
            self._log(_start, duration, len(text))
        except Exception as e:
            return {"error": str(e), "model": self.model_name}

        if not return_base64:
            return Response(content=audio_bytes, media_type="audio/wav")

        return {
            "model": self.model_name,
            "sample_rate": sample_rate,
            "duration": duration,
            "format": output_format,
            "audio": base64.b64encode(audio_bytes).decode("utf-8"),
        }

    # ── GET /api/tts — Coqui-compatible raw WAV endpoint ─────────────────

    @_fastapi.get("/api/tts")
    async def generate_raw(
        self,
        text: str = Query(..., description="Text to synthesize"),
        language_id: str = Query("en", description="Language code"),
        speaker_id: str | None = Query(None, description="Speaker name"),
    ) -> Response:
        """Coqui XTTS-compatible endpoint — returns raw WAV bytes.

        Responds 400 when ``text`` is empty and 500 with the exception
        message when synthesis fails.
        """
        _start = time.time()
        if not text:
            return Response(content="text parameter required", status_code=400)

        try:
            audio_bytes, _sr, duration = self._synthesize(
                text, speaker_id, language_id
            )
            self._log(_start, duration, len(text))
            return Response(content=audio_bytes, media_type="audio/wav")
        except Exception as e:
            return Response(content=str(e), status_code=500)

    # ── POST /stream — SSE per-sentence streaming ──────────────────────

    @_fastapi.post("/stream")
    async def generate_stream(self, request: Request) -> StreamingResponse:
        """Stream TTS audio as SSE events, one per sentence.

        Request body: {"text": "...", "language": "en", "speaker": null, "speed": 1.0}

        Each SSE event is a JSON object:
          data: {"text": "sentence", "audio": "<base64 WAV>", "index": 0,
                 "sample_rate": 24000, "duration": 1.23, "done": false}

        A sentence that fails to synthesize yields an event with an empty
        "audio" and an "error" message; the stream then continues.

        Final event:
          data: {"text": "", "audio": "", "index": N, "done": true}
        followed by:
          data: [DONE]

        A missing text or an unparsable body yields a single
        ``{"error": ...}`` event followed by ``[DONE]``.
        """
        try:
            body = await request.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
            body = None
        if not isinstance(body, dict):
            return _sse_error_response("Request body must be a JSON object")

        text = body.get("text", "")
        if not text:
            return _sse_error_response("No text provided")

        speaker = body.get("speaker")
        language = body.get("language")
        speed = body.get("speed", 1.0)

        sentences = _split_sentences(text)

        async def _generate():
            for idx, sentence in enumerate(sentences):
                _start = time.time()
                try:
                    audio_bytes, sample_rate, duration = self._synthesize(
                        sentence, speaker, language, speed
                    )
                    self._log(_start, duration, len(sentence))
                    chunk = {
                        "text": sentence,
                        "audio": base64.b64encode(audio_bytes).decode("utf-8"),
                        "index": idx,
                        "sample_rate": sample_rate,
                        "duration": duration,
                        "done": False,
                    }
                except Exception as e:
                    # Report the failure for this sentence and keep streaming.
                    chunk = {
                        "text": sentence,
                        "audio": "",
                        "index": idx,
                        "error": str(e),
                        "done": False,
                    }
                yield f"data: {json.dumps(chunk)}\n\n"

            yield f'data: {json.dumps({"text": "", "audio": "", "index": len(sentences), "done": True})}\n\n'
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            _generate(),
            media_type="text/event-stream",
            # Ask reverse proxies and caches not to buffer the event stream.
            headers={"X-Accel-Buffering": "no", "Cache-Control": "no-cache"},
        )

    # ── GET /speakers — list available speakers ──────────────────────────

    @_fastapi.get("/speakers")
    def list_speakers(self) -> dict[str, Any]:
        """List available speakers for multi-speaker models.

        Returns the model name, the speaker list (empty when the model
        exposes none) and an ``is_multi_speaker`` flag.
        """
        speakers = []
        if hasattr(self.tts, "speakers") and self.tts.speakers:
            speakers = self.tts.speakers
        return {
            "model": self.model_name,
            "speakers": speakers,
            "is_multi_speaker": len(speakers) > 0,
        }


app = TTSDeployment.bind()