#!/usr/bin/env python3
"""
Streaming TTS Service

Real-time Text-to-Speech service that processes synthesis requests from NATS:
1. Subscribe to TTS requests on "ai.voice.tts.request.{session_id}"
2. Synthesize speech using Coqui XTTS via HTTP API
3. Stream audio chunks back via "ai.voice.tts.audio.{session_id}"
4. Support for voice cloning and multi-speaker synthesis

This enables real-time voice synthesis for voice assistant applications.
"""

import asyncio
import base64
import logging
import os
import signal
import time
from typing import Dict, Optional

import httpx
import msgpack
import nats
import nats.js
from nats.aio.msg import Msg

# OpenTelemetry imports
from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("tts-streaming")


def setup_telemetry():
    """Initialize OpenTelemetry tracing and metrics with HyperDX support.

    Exporter selection: when HYPERDX_ENABLED is true AND an API key is set,
    spans/metrics go to HyperDX over OTLP/HTTP; otherwise they go to the
    configured OTLP/gRPC collector endpoint (insecure).

    Returns:
        Tuple of (tracer, meter), or (None, None) when OTEL_ENABLED is false.
    """
    otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true"
    if not otel_enabled:
        logger.info("OpenTelemetry disabled")
        return None, None

    otel_endpoint = os.environ.get(
        "OTEL_EXPORTER_OTLP_ENDPOINT",
        "http://opentelemetry-collector.observability.svc.cluster.local:4317",
    )
    service_name = os.environ.get("OTEL_SERVICE_NAME", "tts-streaming")
    service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")

    hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
    hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
    # HyperDX is only used when explicitly enabled AND an API key is present.
    use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key

    resource = Resource.create({
        SERVICE_NAME: service_name,
        SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
        SERVICE_NAMESPACE: service_namespace,
        "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
        "host.name": os.environ.get("HOSTNAME", "unknown"),
    })

    trace_provider = TracerProvider(resource=resource)

    if use_hyperdx:
        logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
        headers = {"authorization": hyperdx_api_key}
        otlp_span_exporter = OTLPSpanExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/traces",
            headers=headers
        )
        otlp_metric_exporter = OTLPMetricExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/metrics",
            headers=headers
        )
    else:
        otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
        otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)

    trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
    trace.set_tracer_provider(trace_provider)

    metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000)
    meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
    metrics.set_meter_provider(meter_provider)

    # Auto-instrument outgoing httpx calls and inject trace context into log records.
    HTTPXClientInstrumentor().instrument()
    LoggingInstrumentor().instrument(set_logging_format=True)

    destination = "HyperDX" if use_hyperdx else "OTEL Collector"
    logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")

    return trace.get_tracer(__name__), metrics.get_meter(__name__)


# Configuration from environment
XTTS_URL = os.environ.get("XTTS_URL", "http://xtts-predictor.ai-ml.svc.cluster.local")
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")

# NATS subjects
REQUEST_SUBJECT_PREFIX = "ai.voice.tts.request"  # ai.voice.tts.request.{session_id}
AUDIO_SUBJECT_PREFIX = "ai.voice.tts.audio"      # ai.voice.tts.audio.{session_id}
STATUS_SUBJECT_PREFIX = "ai.voice.tts.status"    # ai.voice.tts.status.{session_id}

# TTS parameters
DEFAULT_SPEAKER = os.environ.get("TTS_DEFAULT_SPEAKER", "default")
DEFAULT_LANGUAGE = os.environ.get("TTS_DEFAULT_LANGUAGE", "en")
AUDIO_CHUNK_SIZE = int(os.environ.get("TTS_AUDIO_CHUNK_SIZE", "32768"))  # 32KB chunks for streaming
SAMPLE_RATE = int(os.environ.get("TTS_SAMPLE_RATE", "24000"))  # XTTS default sample rate


class StreamingTTS:
    """Streaming Text-to-Speech service using Coqui XTTS.

    Lifecycle: construct, then `await run()` which wires up NATS/JetStream,
    subscribes to request subjects, and blocks until SIGTERM/SIGINT.
    """

    def __init__(self):
        self.nc = None                 # NATS client connection
        self.js = None                 # JetStream context
        self.http_client = None        # Shared httpx.AsyncClient for XTTS calls
        self.running = True            # Cleared by the signal handler to stop run()
        self.is_healthy = False        # Set True once setup() completes
        self.tracer = None
        self.meter = None
        self.synthesis_counter = None  # OTel counter (None when telemetry disabled)
        self.synthesis_duration = None # OTel histogram (None when telemetry disabled)
        self.active_sessions: Dict[str, dict] = {}

    async def setup(self):
        """Initialize telemetry, NATS/JetStream, and the XTTS HTTP client."""
        self.tracer, self.meter = setup_telemetry()

        if self.meter:
            self.synthesis_counter = self.meter.create_counter(
                name="tts_synthesis_total",
                description="Total number of TTS synthesis requests",
                unit="1"
            )
            self.synthesis_duration = self.meter.create_histogram(
                name="tts_synthesis_duration_seconds",
                description="Duration of TTS synthesis",
                unit="s"
            )

        # NATS connection
        self.nc = await nats.connect(NATS_URL)
        logger.info(f"Connected to NATS at {NATS_URL}")

        # Initialize JetStream context
        self.js = self.nc.jetstream()

        # Create or update stream for TTS messages
        try:
            stream_config = nats.js.api.StreamConfig(
                name="AI_VOICE_TTS",
                subjects=["ai.voice.tts.>"],
                retention=nats.js.api.RetentionPolicy.LIMITS,
                max_age=300,  # 5 minutes (nats-py takes seconds — verify against client version)
                storage=nats.js.api.StorageType.MEMORY,
            )
            await self.js.add_stream(stream_config)
            logger.info("Created/updated JetStream stream: AI_VOICE_TTS")
        except Exception as e:
            # Stream may already exist with a different config; best-effort only.
            logger.info(f"JetStream stream setup: {e}")

        # HTTP client for XTTS service; generous timeout for long synthesis jobs.
        self.http_client = httpx.AsyncClient(timeout=180.0)
        logger.info("HTTP client initialized")

        self.is_healthy = True

    async def synthesize(self, text: str, speaker: str = None,
                         language: str = None,
                         speaker_wav_b64: str = None) -> Optional[bytes]:
        """Synthesize speech using XTTS API.

        Args:
            text: Text to synthesize.
            speaker: Speaker id; falls back to DEFAULT_SPEAKER.
            language: Language code; falls back to DEFAULT_LANGUAGE.
            speaker_wav_b64: Optional base64 reference audio for voice cloning.

        Returns:
            Raw audio bytes on success, None on failure (errors are logged,
            never raised to the caller).
        """
        start_time = time.time()
        try:
            # Build request payload
            payload = {
                "text": text,
                "speaker": speaker or DEFAULT_SPEAKER,
                "language": language or DEFAULT_LANGUAGE,
            }

            # Add speaker reference audio for voice cloning if provided
            if speaker_wav_b64:
                payload["speaker_wav"] = speaker_wav_b64

            # Call XTTS API
            response = await self.http_client.post(
                f"{XTTS_URL}/v1/audio/speech",
                json=payload
            )
            response.raise_for_status()

            audio_bytes = response.content
            duration = time.time() - start_time
            # Real-time factor: synthesis wall time vs. audio length,
            # assuming 16-bit (2 bytes/sample) mono PCM at SAMPLE_RATE.
            audio_duration = len(audio_bytes) / (SAMPLE_RATE * 2)
            rtf = duration / audio_duration if audio_duration > 0 else 0

            logger.info(f"Synthesized {len(audio_bytes)} bytes in {duration:.2f}s (RTF: {rtf:.2f})")

            if self.synthesis_duration:
                self.synthesis_duration.record(duration, {"speaker": speaker or DEFAULT_SPEAKER})

            return audio_bytes

        except Exception as e:
            logger.error(f"Synthesis failed: {e}")
            return None

    async def stream_audio(self, session_id: str, audio_bytes: bytes):
        """Stream audio back to client in AUDIO_CHUNK_SIZE-byte msgpack chunks."""
        total_chunks = (len(audio_bytes) + AUDIO_CHUNK_SIZE - 1) // AUDIO_CHUNK_SIZE

        for i in range(0, len(audio_bytes), AUDIO_CHUNK_SIZE):
            chunk = audio_bytes[i:i + AUDIO_CHUNK_SIZE]
            chunk_index = i // AUDIO_CHUNK_SIZE
            is_last = (i + AUDIO_CHUNK_SIZE) >= len(audio_bytes)

            message = {
                "session_id": session_id,
                "chunk_index": chunk_index,
                "total_chunks": total_chunks,
                "audio_b64": base64.b64encode(chunk).decode(),
                "is_last": is_last,
                "timestamp": time.time(),
                "sample_rate": SAMPLE_RATE,
            }

            await self.nc.publish(
                f"{AUDIO_SUBJECT_PREFIX}.{session_id}",
                msgpack.packb(message)
            )
            logger.debug(f"Sent chunk {chunk_index + 1}/{total_chunks} for session {session_id}")

        logger.info(f"Streamed {total_chunks} chunks for session {session_id}")

    async def handle_request(self, msg: Msg):
        """Handle incoming TTS request.

        Parses the msgpack payload, synthesizes audio, and either streams it
        in chunks or publishes it whole depending on the "stream" flag.
        Status updates are published on the status subject throughout.
        """
        # Known before the try so the error path can publish a status even
        # when the failure happens mid-parse.
        session_id = None
        try:
            # Extract session_id from subject: ai.voice.tts.request.{session_id}
            subject_parts = msg.subject.split('.')
            if len(subject_parts) < 5:
                logger.warning(f"Invalid subject format: {msg.subject}")
                return
            session_id = subject_parts[4]

            # Parse request using msgpack
            data = msgpack.unpackb(msg.data, raw=False)
            text = data.get("text", "")
            speaker = data.get("speaker")
            language = data.get("language")
            speaker_wav_b64 = data.get("speaker_wav_b64")  # For voice cloning
            stream = data.get("stream", True)  # Default to streaming

            if not text:
                logger.warning(f"Empty text for session {session_id}")
                await self.publish_status(session_id, "error", "Empty text provided")
                return

            logger.info(f"Processing TTS request for session {session_id}: {text[:50]}...")

            if self.synthesis_counter:
                self.synthesis_counter.add(1, {"session_id": session_id})

            # Publish status: processing
            await self.publish_status(session_id, "processing", f"Synthesizing {len(text)} characters")

            # Synthesize audio
            audio_bytes = await self.synthesize(text, speaker, language, speaker_wav_b64)

            if audio_bytes:
                if stream:
                    # Stream audio in chunks
                    await self.stream_audio(session_id, audio_bytes)
                else:
                    # Send complete audio in one message
                    message = {
                        "session_id": session_id,
                        "audio_b64": base64.b64encode(audio_bytes).decode(),
                        "timestamp": time.time(),
                        "sample_rate": SAMPLE_RATE,
                    }
                    await self.nc.publish(
                        f"{AUDIO_SUBJECT_PREFIX}.{session_id}",
                        msgpack.packb(message)
                    )
                await self.publish_status(session_id, "completed", f"Audio size: {len(audio_bytes)} bytes")
            else:
                await self.publish_status(session_id, "error", "Synthesis failed")

        except Exception as e:
            logger.error(f"Error handling TTS request: {e}", exc_info=True)
            # Best-effort error status; only possible once session_id is known.
            if session_id is not None:
                try:
                    await self.publish_status(session_id, "error", str(e))
                except Exception:
                    pass

    async def publish_status(self, session_id: str, status: str, message: str = ""):
        """Publish TTS status update on the per-session status subject."""
        status_msg = {
            "session_id": session_id,
            "status": status,
            "message": message,
            "timestamp": time.time(),
        }
        await self.nc.publish(
            f"{STATUS_SUBJECT_PREFIX}.{session_id}",
            msgpack.packb(status_msg)
        )
        logger.debug(f"Published status '{status}' for session {session_id}")

    async def run(self):
        """Main run loop: set up, subscribe, and block until shutdown signal."""
        await self.setup()

        # Subscribe to TTS requests
        sub = await self.nc.subscribe(f"{REQUEST_SUBJECT_PREFIX}.>", cb=self.handle_request)
        logger.info(f"Subscribed to {REQUEST_SUBJECT_PREFIX}.>")

        # Handle shutdown
        def signal_handler():
            self.running = False

        # get_running_loop: we are inside a coroutine, so the loop is running.
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        try:
            # Keep running until a signal clears the flag.
            while self.running:
                await asyncio.sleep(1)
        finally:
            # Always release resources, even if the loop body raised.
            logger.info("Shutting down...")
            await sub.unsubscribe()
            await self.nc.close()
            await self.http_client.aclose()
            logger.info("Shutdown complete")


if __name__ == "__main__":
    service = StreamingTTS()
    asyncio.run(service.run())