feat: add streaming TTS service with Coqui XTTS

- tts_streaming.py: NATS-based TTS using XTTS HTTP API
- Streaming audio chunks for low-latency playback
- Voice cloning support via reference audio
- Multi-language synthesis
- OpenTelemetry instrumentation with HyperDX support
This commit is contained in:
2026-02-02 06:23:34 -05:00
parent fddec4585b
commit d4fafea09b
6 changed files with 592 additions and 1 deletions

356
tts_streaming.py Normal file
View File

@@ -0,0 +1,356 @@
#!/usr/bin/env python3
"""
Streaming TTS Service
Real-time Text-to-Speech service that processes synthesis requests from NATS:
1. Subscribe to TTS requests on "ai.voice.tts.request.{session_id}"
2. Synthesize speech using Coqui XTTS via HTTP API
3. Stream audio chunks back via "ai.voice.tts.audio.{session_id}"
4. Support for voice cloning and multi-speaker synthesis
This enables real-time voice synthesis for voice assistant applications.
"""
import asyncio
import base64
import logging
import os
import signal
import time
from typing import Dict, Optional
import httpx
import msgpack
import nats
import nats.js
from nats.aio.msg import Msg
# OpenTelemetry imports
from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
# Configure logging
# Root logger config: timestamped, named, leveled records (stderr by default).
# NOTE(review): LoggingInstrumentor (in setup_telemetry) may later override
# this format with one that injects trace context — confirm desired precedence.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used throughout this service.
logger = logging.getLogger("tts-streaming")
def setup_telemetry():
    """Initialize OpenTelemetry tracing and metrics with HyperDX support.

    Reads all configuration from environment variables, installs the global
    tracer and meter providers, and instruments httpx and logging.

    Returns:
        tuple: ``(tracer, meter)`` on success, or ``(None, None)`` when
        telemetry is disabled via ``OTEL_ENABLED``.

    NOTE(review): this mutates process-global OTel state (providers and
    instrumentors), so it should be called exactly once at startup.
    """
    otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true"
    if not otel_enabled:
        logger.info("OpenTelemetry disabled")
        return None, None
    # Export destination: in-cluster OTLP collector (gRPC) by default, or
    # HyperDX (OTLP over HTTP) when enabled and an API key is present.
    otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317")
    service_name = os.environ.get("OTEL_SERVICE_NAME", "tts-streaming")
    service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")
    hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
    hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
    # Truthy (holds the key string, not a bool) only when HyperDX is both
    # enabled and configured with a non-empty API key.
    use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key
    # Resource attributes attached to every span and metric from this process.
    resource = Resource.create({
        SERVICE_NAME: service_name,
        SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
        SERVICE_NAMESPACE: service_namespace,
        "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
        "host.name": os.environ.get("HOSTNAME", "unknown"),
    })
    trace_provider = TracerProvider(resource=resource)
    if use_hyperdx:
        logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
        # HyperDX authenticates via a bare "authorization" header carrying
        # the API key, over the OTLP/HTTP transport.
        headers = {"authorization": hyperdx_api_key}
        otlp_span_exporter = OTLPSpanExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/traces",
            headers=headers
        )
        otlp_metric_exporter = OTLPMetricExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/metrics",
            headers=headers
        )
    else:
        # Plaintext gRPC to the in-cluster collector (no TLS inside the mesh).
        otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
        otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)
    trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
    trace.set_tracer_provider(trace_provider)
    # Metrics are exported on a fixed 60s interval.
    metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000)
    meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
    metrics.set_meter_provider(meter_provider)
    # Auto-instrument outbound HTTP (XTTS calls) and log records globally.
    HTTPXClientInstrumentor().instrument()
    LoggingInstrumentor().instrument(set_logging_format=True)
    destination = "HyperDX" if use_hyperdx else "OTEL Collector"
    logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")
    return trace.get_tracer(__name__), metrics.get_meter(__name__)
# Configuration from environment
# Upstream synthesis backend (Coqui XTTS HTTP predictor) and message bus.
XTTS_URL = os.environ.get("XTTS_URL", "http://xtts-predictor.ai-ml.svc.cluster.local")
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
# NATS subjects — the final token of each subject is the session id.
REQUEST_SUBJECT_PREFIX = "ai.voice.tts.request"  # ai.voice.tts.request.{session_id}
AUDIO_SUBJECT_PREFIX = "ai.voice.tts.audio"  # ai.voice.tts.audio.{session_id}
STATUS_SUBJECT_PREFIX = "ai.voice.tts.status"  # ai.voice.tts.status.{session_id}
# TTS parameters — used as fallbacks when a request omits them.
DEFAULT_SPEAKER = os.environ.get("TTS_DEFAULT_SPEAKER", "default")
DEFAULT_LANGUAGE = os.environ.get("TTS_DEFAULT_LANGUAGE", "en")
AUDIO_CHUNK_SIZE = int(os.environ.get("TTS_AUDIO_CHUNK_SIZE", "32768"))  # 32KB chunks for streaming
SAMPLE_RATE = int(os.environ.get("TTS_SAMPLE_RATE", "24000"))  # XTTS default sample rate
class StreamingTTS:
    """Streaming Text-to-Speech service using Coqui XTTS.

    Subscribes to per-session TTS requests on NATS, synthesizes audio via
    the XTTS HTTP API, and publishes the result back to the session as
    base64-encoded chunks inside msgpack messages. Status updates are
    published on a parallel status subject.
    """

    def __init__(self):
        # Connections are created lazily in setup(), not here.
        self.nc = None  # core NATS connection
        self.js = None  # JetStream context (stream creation only)
        self.http_client = None  # shared httpx.AsyncClient for XTTS calls
        self.running = True  # main-loop flag, cleared by the signal handler
        self.is_healthy = False  # set True once setup() completes
        self.tracer = None  # OTel tracer, or None when telemetry is disabled
        self.meter = None  # OTel meter, or None when telemetry is disabled
        self.synthesis_counter = None  # counter: tts_synthesis_total
        self.synthesis_duration = None  # histogram: synthesis wall time
        # Reserved for per-session bookkeeping; not yet populated anywhere
        # in this module.
        self.active_sessions: Dict[str, dict] = {}

    async def setup(self):
        """Initialize telemetry, NATS/JetStream, and the XTTS HTTP client."""
        self.tracer, self.meter = setup_telemetry()
        if self.meter:
            self.synthesis_counter = self.meter.create_counter(
                name="tts_synthesis_total",
                description="Total number of TTS synthesis requests",
                unit="1"
            )
            self.synthesis_duration = self.meter.create_histogram(
                name="tts_synthesis_duration_seconds",
                description="Duration of TTS synthesis",
                unit="s"
            )
        # NATS connection
        self.nc = await nats.connect(NATS_URL)
        logger.info(f"Connected to NATS at {NATS_URL}")
        # Initialize JetStream context
        self.js = self.nc.jetstream()
        # Create or update the stream for TTS messages. Memory storage with
        # a short retention window: audio chunks are only useful in real time.
        try:
            stream_config = nats.js.api.StreamConfig(
                name="AI_VOICE_TTS",
                subjects=["ai.voice.tts.>"],
                retention=nats.js.api.RetentionPolicy.LIMITS,
                max_age=300,  # 5 minutes
                storage=nats.js.api.StorageType.MEMORY,
            )
            await self.js.add_stream(stream_config)
            logger.info("Created/updated JetStream stream: AI_VOICE_TTS")
        except Exception as e:
            # The stream may already exist with a different config; the
            # service still works over core NATS, so this is non-fatal.
            logger.info(f"JetStream stream setup: {e}")
        # HTTP client for XTTS service; long timeout for long input texts.
        self.http_client = httpx.AsyncClient(timeout=180.0)
        logger.info("HTTP client initialized")
        self.is_healthy = True

    async def synthesize(self, text: str, speaker: str = None, language: str = None,
                         speaker_wav_b64: str = None) -> Optional[bytes]:
        """Synthesize speech using XTTS API.

        Args:
            text: Text to synthesize.
            speaker: Speaker/voice id; falls back to DEFAULT_SPEAKER.
            language: Language code; falls back to DEFAULT_LANGUAGE.
            speaker_wav_b64: Optional base64 reference audio for voice cloning.

        Returns:
            Raw audio bytes on success, or None on any failure (logged).
        """
        start_time = time.time()
        try:
            # Build request payload
            payload = {
                "text": text,
                "speaker": speaker or DEFAULT_SPEAKER,
                "language": language or DEFAULT_LANGUAGE,
            }
            # Add speaker reference audio for voice cloning if provided
            if speaker_wav_b64:
                payload["speaker_wav"] = speaker_wav_b64
            # Call XTTS API
            response = await self.http_client.post(
                f"{XTTS_URL}/v1/audio/speech",
                json=payload
            )
            response.raise_for_status()
            audio_bytes = response.content
            duration = time.time() - start_time
            # RTF (real-time factor) assumes 16-bit mono PCM at SAMPLE_RATE —
            # TODO confirm the XTTS endpoint returns raw PCM, not a container.
            audio_duration = len(audio_bytes) / (SAMPLE_RATE * 2)  # 16-bit audio
            rtf = duration / audio_duration if audio_duration > 0 else 0
            logger.info(f"Synthesized {len(audio_bytes)} bytes in {duration:.2f}s (RTF: {rtf:.2f})")
            if self.synthesis_duration:
                self.synthesis_duration.record(duration, {"speaker": speaker or DEFAULT_SPEAKER})
            return audio_bytes
        except Exception as e:
            # Errors are reported to the caller as None; handle_request turns
            # that into an "error" status for the session.
            logger.error(f"Synthesis failed: {e}")
            return None

    async def stream_audio(self, session_id: str, audio_bytes: bytes):
        """Stream audio back to client in chunks.

        Publishes AUDIO_CHUNK_SIZE-byte base64 chunks on the session's audio
        subject; the final chunk carries is_last=True.
        """
        total_chunks = (len(audio_bytes) + AUDIO_CHUNK_SIZE - 1) // AUDIO_CHUNK_SIZE
        for i in range(0, len(audio_bytes), AUDIO_CHUNK_SIZE):
            chunk = audio_bytes[i:i + AUDIO_CHUNK_SIZE]
            chunk_index = i // AUDIO_CHUNK_SIZE
            is_last = (i + AUDIO_CHUNK_SIZE) >= len(audio_bytes)
            message = {
                "session_id": session_id,
                "chunk_index": chunk_index,
                "total_chunks": total_chunks,
                "audio_b64": base64.b64encode(chunk).decode(),
                "is_last": is_last,
                "timestamp": time.time(),
                "sample_rate": SAMPLE_RATE,
            }
            await self.nc.publish(
                f"{AUDIO_SUBJECT_PREFIX}.{session_id}",
                msgpack.packb(message)
            )
            logger.debug(f"Sent chunk {chunk_index + 1}/{total_chunks} for session {session_id}")
        logger.info(f"Streamed {total_chunks} chunks for session {session_id}")

    async def handle_request(self, msg: Msg):
        """Handle incoming TTS request.

        Parses the msgpack request, synthesizes audio, and either streams it
        in chunks or sends it whole, publishing status updates throughout.
        """
        # Bound before the try block so the error path below can safely test
        # it — previously an early failure raised UnboundLocalError here.
        session_id = None
        try:
            # Extract session_id from subject: ai.voice.tts.request.{session_id}
            subject_parts = msg.subject.split('.')
            if len(subject_parts) < 5:
                logger.warning(f"Invalid subject format: {msg.subject}")
                return
            session_id = subject_parts[4]
            # Parse request using msgpack
            data = msgpack.unpackb(msg.data, raw=False)
            text = data.get("text", "")
            speaker = data.get("speaker")
            language = data.get("language")
            speaker_wav_b64 = data.get("speaker_wav_b64")  # For voice cloning
            stream = data.get("stream", True)  # Default to streaming
            if not text:
                logger.warning(f"Empty text for session {session_id}")
                await self.publish_status(session_id, "error", "Empty text provided")
                return
            logger.info(f"Processing TTS request for session {session_id}: {text[:50]}...")
            if self.synthesis_counter:
                self.synthesis_counter.add(1, {"session_id": session_id})
            # Publish status: processing
            await self.publish_status(session_id, "processing", f"Synthesizing {len(text)} characters")
            # Synthesize audio
            audio_bytes = await self.synthesize(text, speaker, language, speaker_wav_b64)
            if audio_bytes:
                if stream:
                    # Stream audio in chunks
                    await self.stream_audio(session_id, audio_bytes)
                else:
                    # Send complete audio in one message
                    message = {
                        "session_id": session_id,
                        "audio_b64": base64.b64encode(audio_bytes).decode(),
                        "timestamp": time.time(),
                        "sample_rate": SAMPLE_RATE,
                    }
                    await self.nc.publish(
                        f"{AUDIO_SUBJECT_PREFIX}.{session_id}",
                        msgpack.packb(message)
                    )
                await self.publish_status(session_id, "completed", f"Audio size: {len(audio_bytes)} bytes")
            else:
                await self.publish_status(session_id, "error", "Synthesis failed")
        except Exception as e:
            logger.error(f"Error handling TTS request: {e}", exc_info=True)
            # Best-effort error report; only possible once the session id is
            # known, and never allowed to raise out of the handler.
            if session_id:
                try:
                    await self.publish_status(session_id, "error", str(e))
                except Exception:
                    pass

    async def publish_status(self, session_id: str, status: str, message: str = ""):
        """Publish TTS status update.

        Args:
            session_id: Target session; forms the last subject token.
            status: One of "processing", "completed", "error".
            message: Optional human-readable detail.
        """
        status_msg = {
            "session_id": session_id,
            "status": status,
            "message": message,
            "timestamp": time.time(),
        }
        await self.nc.publish(
            f"{STATUS_SUBJECT_PREFIX}.{session_id}",
            msgpack.packb(status_msg)
        )
        logger.debug(f"Published status '{status}' for session {session_id}")

    async def run(self):
        """Main run loop.

        Sets up connections, subscribes to request subjects, then idles until
        SIGTERM/SIGINT; connections are always closed on the way out.
        """
        await self.setup()
        # Subscribe to TTS requests
        sub = await self.nc.subscribe(f"{REQUEST_SUBJECT_PREFIX}.>", cb=self.handle_request)
        logger.info(f"Subscribed to {REQUEST_SUBJECT_PREFIX}.>")

        # Handle shutdown
        def signal_handler():
            self.running = False

        # get_running_loop() is correct inside a coroutine under asyncio.run();
        # get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)
        try:
            # Keep running
            while self.running:
                await asyncio.sleep(1)
        finally:
            # Cleanup runs even if the wait loop is cancelled or raises.
            logger.info("Shutting down...")
            await sub.unsubscribe()
            await self.nc.close()
            await self.http_client.aclose()
            logger.info("Shutdown complete")
if __name__ == "__main__":
    # Entry point: construct the service and run it until signalled to stop.
    asyncio.run(StreamingTTS().run())