# Changes in this revision:
# - Add py.typed marker for PEP 561 type-hint support.
# - Add ray_utils module for Ray handle detection and caching.
# - Update all clients (Embeddings, LLM, TTS, STT, Reranker) to use Ray handles
#   when running inside a Ray cluster, for faster internal calls.
# - Add .pre-commit-config.yaml with ruff and standard hooks.
# - Add pre-commit and ray[serve] to the optional dependencies.
# - Bump ruff version to 0.4.0.
"""
TTS service client (Coqui XTTS).

Supports both HTTP (for external callers) and Ray Serve handles (for callers
inside the Ray cluster), preferring the Ray handle when one is available.
"""

import asyncio
import logging
from typing import Any, Optional

import httpx

from handler_base.config import TTSSettings
from handler_base.ray_utils import get_ray_handle
from handler_base.telemetry import create_span

# Module-scoped logger, named after this module so callers can tune its level.
logger = logging.getLogger(name=__name__)


class TTSClient:
    """
    Client for the TTS service (Coqui XTTS).

    When running inside Ray, automatically uses Ray handles for faster
    internal communication. Falls back to HTTP for external calls.

    Usage:
        client = TTSClient()
        audio_bytes = await client.synthesize("Hello world")

    The client can also be used as an async context manager, which closes
    the underlying HTTP client on exit:

        async with TTSClient() as client:
            audio_bytes = await client.synthesize("Hello world")
    """

    # Ray Serve deployment configuration
    RAY_DEPLOYMENT_NAME = "TTSDeployment"
    RAY_APP_NAME = "tts"

    # Default HTTP timeout in seconds (TTS synthesis can be slow).
    DEFAULT_TIMEOUT = 120.0

    def __init__(
        self,
        settings: Optional[TTSSettings] = None,
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """
        Create a TTS client.

        Args:
            settings: Service settings; a default ``TTSSettings()`` is built
                when omitted. ``tts_url`` is used as the HTTP base URL and
                ``tts_language`` as the default synthesis language.
            timeout: HTTP timeout in seconds for requests to the TTS service.
        """
        self.settings = settings if settings is not None else TTSSettings()
        self._client = httpx.AsyncClient(
            base_url=self.settings.tts_url,
            timeout=timeout,
        )
        # The Ray handle is resolved lazily on first use. Once a lookup
        # completes, its result (possibly None) is cached and not repeated.
        self._ray_handle: Optional[Any] = None
        self._ray_checked = False

    async def __aenter__(self) -> "TTSClient":
        """Enter the async context; returns the client itself."""
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        """Exit the async context, closing the HTTP client."""
        await self.close()

    def _get_ray_handle(self) -> Optional[Any]:
        """Return the cached Ray handle, looking it up on the first call only."""
        if not self._ray_checked:
            self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
            self._ray_checked = True
        return self._ray_handle

    async def close(self) -> None:
        """Close the HTTP client."""
        await self._client.aclose()

    async def synthesize(
        self,
        text: str,
        language: Optional[str] = None,
        speaker: Optional[str] = None,
    ) -> bytes:
        """
        Synthesize speech from text.

        The Ray handle is tried first when available. Any failure on that
        path is logged and the request is retried over HTTP.

        Args:
            text: Text to synthesize.
            language: Language code (e.g., "en", "es", "fr"). Falls back to
                ``settings.tts_language`` when empty.
            speaker: Speaker ID or reference. On the HTTP path it is sent as
                ``speaker_id`` only when truthy.

        Returns:
            WAV audio bytes.

        Raises:
            httpx.HTTPStatusError: If the HTTP fallback returns an error status.
            httpx.HTTPError: On transport failures of the HTTP fallback.
        """
        language = language or self.settings.tts_language

        with create_span("tts.synthesize") as span:
            if span:
                span.set_attribute("tts.language", language)
                span.set_attribute("tts.text_length", len(text))

            # Try Ray handle first (faster internal path)
            handle = self._get_ray_handle()
            if handle is not None:
                try:
                    if span:
                        span.set_attribute("tts.transport", "ray")
                    audio_bytes = await handle.synthesize.remote(text, language, speaker)
                    if span:
                        span.set_attribute("tts.audio_size", len(audio_bytes))
                    return audio_bytes
                except Exception as e:
                    # Deliberate best-effort fast path: any Ray failure
                    # degrades to HTTP instead of failing the request.
                    logger.warning("Ray handle failed, falling back to HTTP: %s", e)

            # HTTP fallback
            if span:
                span.set_attribute("tts.transport", "http")

            params = {
                "text": text,
                "language_id": language,
            }
            if speaker:
                params["speaker_id"] = speaker

            response = await self._client.get("/api/tts", params=params)
            response.raise_for_status()
            audio_bytes = response.content

            if span:
                span.set_attribute("tts.audio_size", len(audio_bytes))

            return audio_bytes

    async def synthesize_to_file(
        self,
        text: str,
        output_path: str,
        language: Optional[str] = None,
        speaker: Optional[str] = None,
    ) -> None:
        """
        Synthesize speech and save it to a file.

        Args:
            text: Text to synthesize.
            output_path: Path to save the audio file (overwritten if present).
            language: Language code.
            speaker: Speaker ID.
        """
        audio_bytes = await self.synthesize(text, language, speaker)
        # Do the blocking file write in a worker thread so that large audio
        # payloads do not stall the event loop.
        await asyncio.to_thread(self._write_bytes, output_path, audio_bytes)

    @staticmethod
    def _write_bytes(path: str, data: bytes) -> None:
        """Write ``data`` to ``path`` in binary mode, replacing any contents."""
        with open(path, "wb") as f:
            f.write(data)

    async def get_speakers(self) -> list[dict]:
        """
        Get available speakers/voices.

        Returns:
            The decoded JSON body of ``/api/speakers``. Returns an empty list
            if the request or JSON decoding fails; the failure is logged.
        """
        try:
            response = await self._client.get("/api/speakers")
            response.raise_for_status()
            return response.json()
        except Exception as e:
            # Best-effort: a missing speaker list should not break callers,
            # but the reason must not be silently discarded.
            logger.warning("Failed to fetch TTS speakers: %s", e)
            return []

    async def health(self) -> bool:
        """Return True if ``/health`` responds with HTTP 200, else False."""
        try:
            response = await self._client.get("/health")
            return response.status_code == 200
        except Exception as e:
            # Health probes are expected to fail while the service is down;
            # log quietly rather than raising.
            logger.debug("TTS health check failed: %s", e)
            return False