#!/usr/bin/env python3
"""
Streaming TTS Service

Real-time Text-to-Speech service that processes synthesis requests from NATS:

1. Subscribe to TTS requests on "ai.voice.tts.request.{session_id}"
2. Synthesize speech using Coqui XTTS via HTTP API
3. Stream audio chunks back via "ai.voice.tts.audio.{session_id}"
4. Support for voice cloning and multi-speaker synthesis

This enables real-time voice synthesis for voice assistant applications.
"""

import asyncio
import base64
import contextlib
import json
import logging
import os
import signal
import time
from dataclasses import dataclass
from pathlib import Path

import httpx
import msgpack
import nats
import nats.js
from nats.aio.msg import Msg

# OpenTelemetry imports
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
    OTLPMetricExporter as OTLPMetricExporterHTTP,
)
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
    OTLPSpanExporter as OTLPSpanExporterHTTP,
)
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, SERVICE_VERSION, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("tts-streaming")


def setup_telemetry():
    """Initialize OpenTelemetry tracing and metrics with HyperDX support.

    Reads configuration from the environment (OTEL_* / HYPERDX_* variables)
    and installs global tracer/meter providers plus httpx and logging
    instrumentation.

    Returns:
        Tuple ``(tracer, meter)``, or ``(None, None)`` when OTEL is disabled.
    """
    # Guard clause: telemetry is opt-out via OTEL_ENABLED=false.
    if os.environ.get("OTEL_ENABLED", "true").lower() != "true":
        logger.info("OpenTelemetry disabled")
        return None, None

    collector_endpoint = os.environ.get(
        "OTEL_EXPORTER_OTLP_ENDPOINT",
        "http://opentelemetry-collector.observability.svc.cluster.local:4317",
    )
    svc_name = os.environ.get("OTEL_SERVICE_NAME", "tts-streaming")
    svc_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")
    hdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
    hdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
    # HyperDX export is only used when explicitly enabled AND an API key is set.
    hdx_active = (
        os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hdx_api_key
    )

    otel_resource = Resource.create(
        {
            SERVICE_NAME: svc_name,
            SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
            SERVICE_NAMESPACE: svc_namespace,
            "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
            "host.name": os.environ.get("HOSTNAME", "unknown"),
        }
    )

    provider = TracerProvider(resource=otel_resource)

    if hdx_active:
        # HTTP OTLP exporters pointed directly at HyperDX's ingest endpoint.
        logger.info(f"Configuring HyperDX exporter at {hdx_endpoint}")
        auth_headers = {"authorization": hdx_api_key}
        span_exporter = OTLPSpanExporterHTTP(
            endpoint=f"{hdx_endpoint}/v1/traces", headers=auth_headers
        )
        metric_exporter = OTLPMetricExporterHTTP(
            endpoint=f"{hdx_endpoint}/v1/metrics", headers=auth_headers
        )
    else:
        # Default path: plaintext gRPC to the in-cluster collector.
        span_exporter = OTLPSpanExporter(endpoint=collector_endpoint, insecure=True)
        metric_exporter = OTLPMetricExporter(endpoint=collector_endpoint, insecure=True)

    provider.add_span_processor(BatchSpanProcessor(span_exporter))
    trace.set_tracer_provider(provider)

    # Export metrics once a minute.
    reader = PeriodicExportingMetricReader(metric_exporter, export_interval_millis=60000)
    metrics.set_meter_provider(MeterProvider(resource=otel_resource, metric_readers=[reader]))

    # Auto-instrument outgoing httpx calls and log records.
    HTTPXClientInstrumentor().instrument()
    LoggingInstrumentor().instrument(set_logging_format=True)

    destination = "HyperDX" if hdx_active else "OTEL Collector"
    logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {svc_name}")
    return trace.get_tracer(__name__), metrics.get_meter(__name__)
# Configuration from environment
XTTS_URL = os.environ.get("XTTS_URL", "http://xtts-predictor.ai-ml.svc.cluster.local")
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")

# NATS subjects
REQUEST_SUBJECT_PREFIX = "ai.voice.tts.request"  # ai.voice.tts.request.{session_id}
AUDIO_SUBJECT_PREFIX = "ai.voice.tts.audio"  # ai.voice.tts.audio.{session_id}
STATUS_SUBJECT_PREFIX = "ai.voice.tts.status"  # ai.voice.tts.status.{session_id}
VOICES_LIST_SUBJECT = "ai.voice.tts.voices.list"  # List available voices
VOICES_REFRESH_SUBJECT = "ai.voice.tts.voices.refresh"  # Trigger registry refresh

# TTS parameters
DEFAULT_SPEAKER = os.environ.get("TTS_DEFAULT_SPEAKER", "default")
DEFAULT_LANGUAGE = os.environ.get("TTS_DEFAULT_LANGUAGE", "en")
AUDIO_CHUNK_SIZE = int(os.environ.get("TTS_AUDIO_CHUNK_SIZE", "32768"))  # 32KB chunks for streaming
SAMPLE_RATE = int(os.environ.get("TTS_SAMPLE_RATE", "24000"))  # XTTS default sample rate

# Custom voice model store (populated by coqui-voice-training Argo workflow)
VOICE_MODEL_STORE = os.environ.get("VOICE_MODEL_STORE", "/models/tts/custom")
VOICE_REGISTRY_REFRESH_SECONDS = int(os.environ.get("VOICE_REGISTRY_REFRESH_SECONDS", "300"))


@dataclass
class CustomVoice:
    """A custom trained voice produced by the coqui-voice-training pipeline."""

    # Display/lookup name of the voice (falls back to its directory name).
    name: str
    # Absolute path to the trained model weights (model.pth).
    model_path: str
    # Absolute path to the model's config.json, or "" when absent.
    config_path: str
    # Creation timestamp as recorded in model_info.json (free-form string).
    created_at: str
    # ISO language code the voice was trained for.
    language: str = "en"
    # Model family identifier; "coqui-tts" unless model_info.json says otherwise.
    model_type: str = "coqui-tts"
class VoiceRegistry:
    """Registry of custom trained voices discovered from the model store.

    Scans ``VOICE_MODEL_STORE`` for directories produced by the
    ``coqui-voice-training`` Argo workflow. Each directory must contain
    ``model_info.json`` and ``model.pth``.
    """

    def __init__(self, model_store_path: str) -> None:
        # Root directory the training pipeline writes finished voices into.
        self.model_store = Path(model_store_path)
        self.voices: dict[str, CustomVoice] = {}
        self._last_refresh: float = 0.0

    def _load_voice(self, entry: Path) -> CustomVoice | None:
        """Build a CustomVoice from one model-store directory.

        Returns None (after logging) when the weights file is missing;
        raises on unreadable/invalid model_info.json so the caller can log it.
        """
        with open(entry / "model_info.json") as handle:
            info = json.load(handle)

        weights = entry / "model.pth"
        cfg = entry / "config.json"
        if not weights.exists():
            logger.warning(f"Model file missing for voice: {entry.name}")
            return None

        return CustomVoice(
            name=info.get("name", entry.name),
            model_path=str(weights),
            config_path=str(cfg) if cfg.exists() else "",
            created_at=info.get("created_at", ""),
            language=info.get("language", "en"),
            model_type=info.get("type", "coqui-tts"),
        )

    def refresh(self) -> int:
        """Scan the model store for available custom voices.

        Returns:
            Number of voices discovered.
        """
        if not self.model_store.exists():
            logger.warning(f"Voice model store not found: {self.model_store}")
            return 0

        found: dict[str, CustomVoice] = {}
        for entry in self.model_store.iterdir():
            # Only directories containing a model_info.json are voice candidates.
            if not entry.is_dir():
                continue
            if not (entry / "model_info.json").exists():
                continue
            try:
                voice = self._load_voice(entry)
            except Exception as e:
                logger.error(f"Failed to load voice info from {entry}: {e}")
                continue
            if voice is not None:
                found[voice.name] = voice

        # Log the diff against the previous scan before swapping it in.
        added = set(found) - set(self.voices)
        removed = set(self.voices) - set(found)
        if added:
            logger.info(f"New voices discovered: {', '.join(sorted(added))}")
        if removed:
            logger.info(f"Voices removed: {', '.join(sorted(removed))}")

        self.voices = found
        self._last_refresh = time.time()
        logger.info(f"Voice registry refreshed: {len(self.voices)} custom voice(s) available")
        return len(self.voices)

    def get(self, name: str) -> CustomVoice | None:
        """Get a custom voice by name."""
        return self.voices.get(name)

    def list_voices(self) -> list[dict]:
        """List all available custom voices as serialisable dicts."""
        return [
            {
                "name": v.name,
                "language": v.language,
                "model_type": v.model_type,
                "created_at": v.created_at,
            }
            for v in self.voices.values()
        ]
class StreamingTTS:
    """Streaming Text-to-Speech service using Coqui XTTS."""

    def __init__(self):
        # Connections — all populated in setup().
        self.nc = None  # NATS client
        self.js = None  # JetStream context
        self.http_client = None  # httpx client for the XTTS HTTP API
        # Service lifecycle flags.
        self.running = True
        self.is_healthy = False
        # Telemetry handles; remain None when OTEL is disabled.
        self.tracer = None
        self.meter = None
        self.synthesis_counter = None
        self.synthesis_duration = None
        # Per-session bookkeeping and the custom-voice registry.
        self.active_sessions: dict[str, dict] = {}
        self.voice_registry = VoiceRegistry(VOICE_MODEL_STORE)

    async def setup(self):
        """Initialize telemetry, NATS/JetStream, the XTTS HTTP client and voices."""
        self.tracer, self.meter = setup_telemetry()

        # Metric instruments are only created when telemetry is enabled.
        if self.meter:
            self.synthesis_counter = self.meter.create_counter(
                name="tts_synthesis_total",
                description="Total number of TTS synthesis requests",
                unit="1",
            )
            self.synthesis_duration = self.meter.create_histogram(
                name="tts_synthesis_duration_seconds",
                description="Duration of TTS synthesis",
                unit="s",
            )

        # NATS connection
        self.nc = await nats.connect(NATS_URL)
        logger.info(f"Connected to NATS at {NATS_URL}")

        # Initialize JetStream context
        self.js = self.nc.jetstream()

        # Create or update the short-lived in-memory stream for TTS traffic.
        # Failure here is non-fatal (the stream may already exist), so it is
        # logged at info level rather than raised.
        try:
            cfg = nats.js.api.StreamConfig(
                name="AI_VOICE_TTS",
                subjects=["ai.voice.tts.>"],
                retention=nats.js.api.RetentionPolicy.LIMITS,
                max_age=300,  # 5 minutes
                storage=nats.js.api.StorageType.MEMORY,
            )
            await self.js.add_stream(cfg)
            logger.info("Created/updated JetStream stream: AI_VOICE_TTS")
        except Exception as e:
            logger.info(f"JetStream stream setup: {e}")

        # HTTP client for XTTS service
        self.http_client = httpx.AsyncClient(timeout=180.0)
        logger.info("HTTP client initialized")

        # Discover custom voices from model store
        self.voice_registry.refresh()

        self.is_healthy = True
""" start_time = time.time() try: # Check if speaker matches a custom trained voice custom_voice = self.voice_registry.get(speaker) if speaker else None # Build request payload payload = { "text": text, "speaker": speaker or DEFAULT_SPEAKER, "language": language or DEFAULT_LANGUAGE, } if custom_voice: # Custom voice from coqui-voice-training pipeline payload["model_path"] = custom_voice.model_path if custom_voice.config_path: payload["config_path"] = custom_voice.config_path payload["language"] = language or custom_voice.language logger.info( f"Using custom voice '{custom_voice.name}' from {custom_voice.model_path}" ) elif speaker_wav_b64: # Ad-hoc voice cloning via reference audio payload["speaker_wav"] = speaker_wav_b64 # Call XTTS API response = await self.http_client.post(f"{XTTS_URL}/v1/audio/speech", json=payload) response.raise_for_status() audio_bytes = response.content duration = time.time() - start_time audio_duration = len(audio_bytes) / (SAMPLE_RATE * 2) # 16-bit audio rtf = duration / audio_duration if audio_duration > 0 else 0 voice_label = custom_voice.name if custom_voice else (speaker or DEFAULT_SPEAKER) logger.info( f"Synthesized {len(audio_bytes)} bytes in {duration:.2f}s (RTF: {rtf:.2f}, voice: {voice_label})" ) if self.synthesis_duration: self.synthesis_duration.record(duration, {"speaker": voice_label}) return audio_bytes except Exception as e: logger.error(f"Synthesis failed: {e}") return None async def stream_audio(self, session_id: str, audio_bytes: bytes): """Stream audio back to client in chunks.""" total_chunks = (len(audio_bytes) + AUDIO_CHUNK_SIZE - 1) // AUDIO_CHUNK_SIZE for i in range(0, len(audio_bytes), AUDIO_CHUNK_SIZE): chunk = audio_bytes[i : i + AUDIO_CHUNK_SIZE] chunk_index = i // AUDIO_CHUNK_SIZE is_last = (i + AUDIO_CHUNK_SIZE) >= len(audio_bytes) message = { "session_id": session_id, "chunk_index": chunk_index, "total_chunks": total_chunks, "audio_b64": base64.b64encode(chunk).decode(), "is_last": is_last, "timestamp": 
time.time(), "sample_rate": SAMPLE_RATE, } await self.nc.publish(f"{AUDIO_SUBJECT_PREFIX}.{session_id}", msgpack.packb(message)) logger.debug(f"Sent chunk {chunk_index + 1}/{total_chunks} for session {session_id}") logger.info(f"Streamed {total_chunks} chunks for session {session_id}") async def handle_request(self, msg: Msg): """Handle incoming TTS request.""" try: # Extract session_id from subject: ai.voice.tts.request.{session_id} subject_parts = msg.subject.split(".") if len(subject_parts) < 5: logger.warning(f"Invalid subject format: {msg.subject}") return session_id = subject_parts[4] # Parse request using msgpack data = msgpack.unpackb(msg.data, raw=False) text = data.get("text", "") speaker = data.get("speaker") language = data.get("language") speaker_wav_b64 = data.get("speaker_wav_b64") # For voice cloning stream = data.get("stream", True) # Default to streaming if not text: logger.warning(f"Empty text for session {session_id}") await self.publish_status(session_id, "error", "Empty text provided") return logger.info(f"Processing TTS request for session {session_id}: {text[:50]}...") if self.synthesis_counter: self.synthesis_counter.add(1, {"session_id": session_id}) # Publish status: processing await self.publish_status( session_id, "processing", f"Synthesizing {len(text)} characters" ) # Synthesize audio audio_bytes = await self.synthesize(text, speaker, language, speaker_wav_b64) if audio_bytes: if stream: # Stream audio in chunks await self.stream_audio(session_id, audio_bytes) else: # Send complete audio in one message message = { "session_id": session_id, "audio_b64": base64.b64encode(audio_bytes).decode(), "timestamp": time.time(), "sample_rate": SAMPLE_RATE, } await self.nc.publish( f"{AUDIO_SUBJECT_PREFIX}.{session_id}", msgpack.packb(message) ) await self.publish_status( session_id, "completed", f"Audio size: {len(audio_bytes)} bytes" ) else: await self.publish_status(session_id, "error", "Synthesis failed") except Exception as e: 
logger.error(f"Error handling TTS request: {e}", exc_info=True) with contextlib.suppress(Exception): await self.publish_status(session_id, "error", str(e)) async def publish_status(self, session_id: str, status: str, message: str = ""): """Publish TTS status update.""" status_msg = { "session_id": session_id, "status": status, "message": message, "timestamp": time.time(), } await self.nc.publish(f"{STATUS_SUBJECT_PREFIX}.{session_id}", msgpack.packb(status_msg)) logger.debug(f"Published status '{status}' for session {session_id}") async def handle_list_voices(self, msg: Msg): """Handle request to list available voices (built-in + custom).""" custom = self.voice_registry.list_voices() response = { "default_speaker": DEFAULT_SPEAKER, "custom_voices": custom, "last_refresh": self.voice_registry._last_refresh, "timestamp": time.time(), } if msg.reply: await msg.respond(msgpack.packb(response)) logger.debug(f"Listed {len(custom)} custom voice(s)") async def handle_refresh_voices(self, msg: Msg): """Handle request to refresh the custom voice registry.""" count = self.voice_registry.refresh() response = { "count": count, "custom_voices": self.voice_registry.list_voices(), "timestamp": time.time(), } if msg.reply: await msg.respond(msgpack.packb(response)) logger.info(f"Voice registry refreshed on demand: {count} voice(s)") async def _periodic_voice_refresh(self): """Periodically refresh the voice registry to pick up newly trained voices.""" while self.running: await asyncio.sleep(VOICE_REGISTRY_REFRESH_SECONDS) if not self.running: break try: self.voice_registry.refresh() except Exception as e: logger.error(f"Periodic voice registry refresh failed: {e}") async def run(self): """Main run loop.""" await self.setup() # Subscribe to TTS requests sub = await self.nc.subscribe(f"{REQUEST_SUBJECT_PREFIX}.>", cb=self.handle_request) logger.info(f"Subscribed to {REQUEST_SUBJECT_PREFIX}.>") # Subscribe to voice management subjects voices_sub = await 
self.nc.subscribe(VOICES_LIST_SUBJECT, cb=self.handle_list_voices) refresh_sub = await self.nc.subscribe(VOICES_REFRESH_SUBJECT, cb=self.handle_refresh_voices) logger.info(f"Subscribed to {VOICES_LIST_SUBJECT} and {VOICES_REFRESH_SUBJECT}") # Start periodic voice registry refresh refresh_task = asyncio.create_task(self._periodic_voice_refresh()) # Handle shutdown def signal_handler(): self.running = False loop = asyncio.get_event_loop() for sig in (signal.SIGTERM, signal.SIGINT): loop.add_signal_handler(sig, signal_handler) # Keep running while self.running: await asyncio.sleep(1) # Cleanup logger.info("Shutting down...") refresh_task.cancel() await sub.unsubscribe() await voices_sub.unsubscribe() await refresh_sub.unsubscribe() await self.nc.close() await self.http_client.aclose() logger.info("Shutdown complete") if __name__ == "__main__": service = StreamingTTS() asyncio.run(service.run())