Files
tts-module/tts_streaming.py
Billy D. 7b3bfc6812
Some checks failed
CI / Test (push) Successful in 43s
CI / Lint (push) Successful in 44s
CI / Docker Build & Push (push) Failing after 24s
CI / Notify (push) Successful in 1s
CI / Release (push) Successful in 6s
feat: custom voice support, CI pipeline, and Renovate config
- VoiceRegistry for trained voices from Kubeflow pipeline
- Custom voice routing in synthesize()
- NATS subjects for listing/refreshing voices
- pyproject.toml with ruff/pytest config
- Full test suite (26 tests)
- Gitea Actions CI (lint, test, docker, notify)
- Renovate config for automated dependency updates

Ref: ADR-0056, ADR-0057
2026-02-13 15:33:27 -05:00

543 lines
20 KiB
Python

#!/usr/bin/env python3
"""
Streaming TTS Service
Real-time Text-to-Speech service that processes synthesis requests from NATS:
1. Subscribe to TTS requests on "ai.voice.tts.request.{session_id}"
2. Synthesize speech using Coqui XTTS via HTTP API
3. Stream audio chunks back via "ai.voice.tts.audio.{session_id}"
4. Support for voice cloning and multi-speaker synthesis
This enables real-time voice synthesis for voice assistant applications.
"""
import asyncio
import base64
import contextlib
import json
import logging
import os
import signal
import time
from dataclasses import dataclass
from pathlib import Path
import httpx
import msgpack
import nats
import nats.js
from nats.aio.msg import Msg
# OpenTelemetry imports
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as OTLPMetricExporterHTTP,
)
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterHTTP,
)
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, SERVICE_VERSION, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
# Module-wide logging setup; every component logs through "tts-streaming".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("tts-streaming")
def setup_telemetry():
    """Configure OpenTelemetry tracing and metrics exporters.

    Exports over OTLP/HTTP to HyperDX when ``HYPERDX_ENABLED`` is set (and an
    API key is present), otherwise over OTLP/gRPC to the in-cluster collector.

    Returns:
        A ``(tracer, meter)`` pair, or ``(None, None)`` when disabled via the
        ``OTEL_ENABLED`` environment variable.
    """
    if os.environ.get("OTEL_ENABLED", "true").lower() != "true":
        logger.info("OpenTelemetry disabled")
        return None, None

    env = os.environ.get
    otel_endpoint = env(
        "OTEL_EXPORTER_OTLP_ENDPOINT",
        "http://opentelemetry-collector.observability.svc.cluster.local:4317",
    )
    service_name = env("OTEL_SERVICE_NAME", "tts-streaming")
    service_namespace = env("OTEL_SERVICE_NAMESPACE", "ai-ml")
    hyperdx_api_key = env("HYPERDX_API_KEY", "")
    hyperdx_endpoint = env("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
    use_hyperdx = env("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key

    resource = Resource.create(
        {
            SERVICE_NAME: service_name,
            SERVICE_VERSION: env("SERVICE_VERSION", "1.0.0"),
            SERVICE_NAMESPACE: service_namespace,
            "deployment.environment": env("DEPLOYMENT_ENV", "production"),
            "host.name": env("HOSTNAME", "unknown"),
        }
    )

    # Pick the exporter pair first, then wire providers.
    if use_hyperdx:
        logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
        headers = {"authorization": hyperdx_api_key}
        span_exporter = OTLPSpanExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/traces", headers=headers
        )
        metric_exporter = OTLPMetricExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/metrics", headers=headers
        )
    else:
        span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
        metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)

    trace_provider = TracerProvider(resource=resource)
    trace_provider.add_span_processor(BatchSpanProcessor(span_exporter))
    trace.set_tracer_provider(trace_provider)

    reader = PeriodicExportingMetricReader(metric_exporter, export_interval_millis=60000)
    metrics.set_meter_provider(MeterProvider(resource=resource, metric_readers=[reader]))

    # Auto-instrument outgoing HTTP calls and log records.
    HTTPXClientInstrumentor().instrument()
    LoggingInstrumentor().instrument(set_logging_format=True)

    destination = "HyperDX" if use_hyperdx else "OTEL Collector"
    logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")
    return trace.get_tracer(__name__), metrics.get_meter(__name__)
# --- Configuration from environment ---
# Base URL of the Coqui XTTS inference service (KServe-style predictor).
XTTS_URL = os.environ.get("XTTS_URL", "http://xtts-predictor.ai-ml.svc.cluster.local")
# NATS server used for all request/status/audio subjects below.
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
# NATS subjects (per-session subjects carry the session id as the last token)
REQUEST_SUBJECT_PREFIX = "ai.voice.tts.request"  # ai.voice.tts.request.{session_id}
AUDIO_SUBJECT_PREFIX = "ai.voice.tts.audio"  # ai.voice.tts.audio.{session_id}
STATUS_SUBJECT_PREFIX = "ai.voice.tts.status"  # ai.voice.tts.status.{session_id}
VOICES_LIST_SUBJECT = "ai.voice.tts.voices.list"  # List available voices
VOICES_REFRESH_SUBJECT = "ai.voice.tts.voices.refresh"  # Trigger registry refresh
# TTS parameters
DEFAULT_SPEAKER = os.environ.get("TTS_DEFAULT_SPEAKER", "default")
DEFAULT_LANGUAGE = os.environ.get("TTS_DEFAULT_LANGUAGE", "en")
AUDIO_CHUNK_SIZE = int(os.environ.get("TTS_AUDIO_CHUNK_SIZE", "32768"))  # 32KB chunks for streaming
SAMPLE_RATE = int(os.environ.get("TTS_SAMPLE_RATE", "24000"))  # XTTS default sample rate
# Custom voice model store (populated by coqui-voice-training Argo workflow)
VOICE_MODEL_STORE = os.environ.get("VOICE_MODEL_STORE", "/models/tts/custom")
# How often the background task rescans the model store for new voices.
VOICE_REGISTRY_REFRESH_SECONDS = int(os.environ.get("VOICE_REGISTRY_REFRESH_SECONDS", "300"))
@dataclass
class CustomVoice:
    """A custom trained voice produced by the coqui-voice-training pipeline."""

    name: str  # voice identifier from model_info.json, falling back to the dir name
    model_path: str  # path to the trained weights file (model.pth)
    config_path: str  # path to config.json, or "" when the file is absent
    created_at: str  # creation timestamp string from model_info.json ("" if missing)
    language: str = "en"  # language code the voice was trained for
    model_type: str = "coqui-tts"  # "type" field from model_info.json
class VoiceRegistry:
    """Registry of custom trained voices discovered from the model store.

    Scans ``VOICE_MODEL_STORE`` for directories produced by the
    ``coqui-voice-training`` Argo workflow. Each directory must contain
    ``model_info.json`` and ``model.pth``.
    """

    def __init__(self, model_store_path: str) -> None:
        self.model_store = Path(model_store_path)
        self.voices: dict[str, CustomVoice] = {}
        self._last_refresh: float = 0.0

    def refresh(self) -> int:
        """Scan the model store for available custom voices.

        Returns:
            Number of voices discovered.
        """
        if not self.model_store.exists():
            logger.warning(f"Voice model store not found: {self.model_store}")
            return 0

        found: dict[str, CustomVoice] = {}
        for entry in self.model_store.iterdir():
            if not entry.is_dir():
                continue
            info_file = entry / "model_info.json"
            if not info_file.exists():
                continue
            try:
                with open(info_file) as fh:
                    info = json.load(fh)
                weights_file = entry / "model.pth"
                if not weights_file.exists():
                    logger.warning(f"Model file missing for voice: {entry.name}")
                    continue
                cfg_file = entry / "config.json"
                voice = CustomVoice(
                    name=info.get("name", entry.name),
                    model_path=str(weights_file),
                    config_path=str(cfg_file) if cfg_file.exists() else "",
                    created_at=info.get("created_at", ""),
                    language=info.get("language", "en"),
                    model_type=info.get("type", "coqui-tts"),
                )
                found[voice.name] = voice
            except Exception as e:
                # A malformed voice directory must not break discovery of the rest.
                logger.error(f"Failed to load voice info from {entry}: {e}")

        # Log the delta against the previous scan before swapping registries.
        previous, current = set(self.voices), set(found)
        if current - previous:
            logger.info(f"New voices discovered: {', '.join(sorted(current - previous))}")
        if previous - current:
            logger.info(f"Voices removed: {', '.join(sorted(previous - current))}")

        self.voices = found
        self._last_refresh = time.time()
        logger.info(f"Voice registry refreshed: {len(self.voices)} custom voice(s) available")
        return len(self.voices)

    def get(self, name: str) -> CustomVoice | None:
        """Return the custom voice registered under *name*, or None."""
        return self.voices.get(name)

    def list_voices(self) -> list[dict]:
        """Return all registered voices as serialisable dicts."""
        voices = []
        for v in self.voices.values():
            voices.append(
                {
                    "name": v.name,
                    "language": v.language,
                    "model_type": v.model_type,
                    "created_at": v.created_at,
                }
            )
        return voices
class StreamingTTS:
    """Streaming Text-to-Speech service using Coqui XTTS.

    Consumes synthesis requests from NATS, calls the XTTS HTTP API, and
    streams the resulting audio back to the requesting session in chunks.
    """

    def __init__(self):
        # All connections are established in setup(); everything starts unset.
        self.nc = None  # core NATS connection
        self.js = None  # JetStream context (used only to create the stream)
        self.http_client = None  # shared httpx.AsyncClient for XTTS calls
        self.running = True  # cleared by the signal handler installed in run()
        self.is_healthy = False  # flipped to True at the end of setup()
        self.tracer = None
        self.meter = None
        self.synthesis_counter = None  # OTEL counter; set only when metrics are enabled
        self.synthesis_duration = None  # OTEL histogram; set only when metrics are enabled
        self.active_sessions: dict[str, dict] = {}
        self.voice_registry = VoiceRegistry(VOICE_MODEL_STORE)

    async def setup(self):
        """Initialize telemetry, NATS, JetStream, the HTTP client, and the voice registry."""
        self.tracer, self.meter = setup_telemetry()
        if self.meter:
            self.synthesis_counter = self.meter.create_counter(
                name="tts_synthesis_total",
                description="Total number of TTS synthesis requests",
                unit="1",
            )
            self.synthesis_duration = self.meter.create_histogram(
                name="tts_synthesis_duration_seconds",
                description="Duration of TTS synthesis",
                unit="s",
            )
        # NATS connection
        self.nc = await nats.connect(NATS_URL)
        logger.info(f"Connected to NATS at {NATS_URL}")
        # Initialize JetStream context
        self.js = self.nc.jetstream()
        # Create or update stream for TTS messages
        try:
            stream_config = nats.js.api.StreamConfig(
                name="AI_VOICE_TTS",
                subjects=["ai.voice.tts.>"],
                retention=nats.js.api.RetentionPolicy.LIMITS,
                max_age=300,  # 5 minutes
                storage=nats.js.api.StorageType.MEMORY,
            )
            await self.js.add_stream(stream_config)
            logger.info("Created/updated JetStream stream: AI_VOICE_TTS")
        except Exception as e:
            # Non-fatal: publishing below goes through core NATS, so the
            # service works even if the stream could not be created/updated.
            logger.info(f"JetStream stream setup: {e}")
        # HTTP client for XTTS service (generous timeout for long synthesis)
        self.http_client = httpx.AsyncClient(timeout=180.0)
        logger.info("HTTP client initialized")
        # Discover custom voices from model store
        self.voice_registry.refresh()
        self.is_healthy = True

    async def synthesize(
        self,
        text: str,
        speaker: str | None = None,
        language: str | None = None,
        speaker_wav_b64: str | None = None,
    ) -> bytes | None:
        """Synthesize speech using XTTS API.

        When *speaker* matches a custom voice from the voice registry the
        request is enriched with the trained model path so the XTTS backend
        loads the fine-tuned model instead of the default one.

        Args:
            text: Text to synthesize.
            speaker: Speaker or custom voice name; falls back to DEFAULT_SPEAKER.
            language: Language code; falls back to DEFAULT_LANGUAGE (or the
                custom voice's trained language).
            speaker_wav_b64: Base64-encoded reference audio for ad-hoc cloning.

        Returns:
            Raw audio bytes from the XTTS service, or None on any failure.
        """
        start_time = time.time()
        try:
            # Check if speaker matches a custom trained voice
            custom_voice = self.voice_registry.get(speaker) if speaker else None
            # Build request payload
            payload = {
                "text": text,
                "speaker": speaker or DEFAULT_SPEAKER,
                "language": language or DEFAULT_LANGUAGE,
            }
            if custom_voice:
                # Custom voice from coqui-voice-training pipeline
                payload["model_path"] = custom_voice.model_path
                if custom_voice.config_path:
                    payload["config_path"] = custom_voice.config_path
                # Explicit request language wins over the voice's trained language.
                payload["language"] = language or custom_voice.language
                logger.info(
                    f"Using custom voice '{custom_voice.name}' from {custom_voice.model_path}"
                )
            elif speaker_wav_b64:
                # Ad-hoc voice cloning via reference audio
                payload["speaker_wav"] = speaker_wav_b64
            # Call XTTS API
            response = await self.http_client.post(f"{XTTS_URL}/v1/audio/speech", json=payload)
            response.raise_for_status()
            audio_bytes = response.content
            duration = time.time() - start_time
            # RTF (real-time factor) = synthesis time / audio duration.
            # Assumes 16-bit mono PCM at SAMPLE_RATE — TODO confirm the XTTS
            # endpoint returns raw PCM rather than a container format.
            audio_duration = len(audio_bytes) / (SAMPLE_RATE * 2)  # 16-bit audio
            rtf = duration / audio_duration if audio_duration > 0 else 0
            voice_label = custom_voice.name if custom_voice else (speaker or DEFAULT_SPEAKER)
            logger.info(
                f"Synthesized {len(audio_bytes)} bytes in {duration:.2f}s (RTF: {rtf:.2f}, voice: {voice_label})"
            )
            if self.synthesis_duration:
                self.synthesis_duration.record(duration, {"speaker": voice_label})
            return audio_bytes
        except Exception as e:
            # Best effort: callers treat None as "synthesis failed".
            logger.error(f"Synthesis failed: {e}")
            return None

    async def stream_audio(self, session_id: str, audio_bytes: bytes):
        """Stream audio back to client in AUDIO_CHUNK_SIZE-byte msgpack chunks.

        Each chunk message carries its index, the total chunk count, an
        ``is_last`` flag, and the sample rate so clients can reassemble/play.
        """
        # Ceiling division so a partial trailing chunk is counted.
        total_chunks = (len(audio_bytes) + AUDIO_CHUNK_SIZE - 1) // AUDIO_CHUNK_SIZE
        for i in range(0, len(audio_bytes), AUDIO_CHUNK_SIZE):
            chunk = audio_bytes[i : i + AUDIO_CHUNK_SIZE]
            chunk_index = i // AUDIO_CHUNK_SIZE
            is_last = (i + AUDIO_CHUNK_SIZE) >= len(audio_bytes)
            message = {
                "session_id": session_id,
                "chunk_index": chunk_index,
                "total_chunks": total_chunks,
                "audio_b64": base64.b64encode(chunk).decode(),
                "is_last": is_last,
                "timestamp": time.time(),
                "sample_rate": SAMPLE_RATE,
            }
            await self.nc.publish(f"{AUDIO_SUBJECT_PREFIX}.{session_id}", msgpack.packb(message))
            logger.debug(f"Sent chunk {chunk_index + 1}/{total_chunks} for session {session_id}")
        logger.info(f"Streamed {total_chunks} chunks for session {session_id}")

    async def handle_request(self, msg: Msg):
        """Handle incoming TTS request.

        Expects a msgpack payload with keys: ``text`` (required), ``speaker``,
        ``language``, ``speaker_wav_b64``, and ``stream`` (default True).
        Publishes status updates and either chunked or single-message audio.
        """
        try:
            # Extract session_id from subject: ai.voice.tts.request.{session_id}
            subject_parts = msg.subject.split(".")
            if len(subject_parts) < 5:
                logger.warning(f"Invalid subject format: {msg.subject}")
                return
            session_id = subject_parts[4]
            # Parse request using msgpack
            data = msgpack.unpackb(msg.data, raw=False)
            text = data.get("text", "")
            speaker = data.get("speaker")
            language = data.get("language")
            speaker_wav_b64 = data.get("speaker_wav_b64")  # For voice cloning
            stream = data.get("stream", True)  # Default to streaming
            if not text:
                logger.warning(f"Empty text for session {session_id}")
                await self.publish_status(session_id, "error", "Empty text provided")
                return
            logger.info(f"Processing TTS request for session {session_id}: {text[:50]}...")
            if self.synthesis_counter:
                self.synthesis_counter.add(1, {"session_id": session_id})
            # Publish status: processing
            await self.publish_status(
                session_id, "processing", f"Synthesizing {len(text)} characters"
            )
            # Synthesize audio
            audio_bytes = await self.synthesize(text, speaker, language, speaker_wav_b64)
            if audio_bytes:
                if stream:
                    # Stream audio in chunks
                    await self.stream_audio(session_id, audio_bytes)
                else:
                    # Send complete audio in one message
                    message = {
                        "session_id": session_id,
                        "audio_b64": base64.b64encode(audio_bytes).decode(),
                        "timestamp": time.time(),
                        "sample_rate": SAMPLE_RATE,
                    }
                    await self.nc.publish(
                        f"{AUDIO_SUBJECT_PREFIX}.{session_id}", msgpack.packb(message)
                    )
                await self.publish_status(
                    session_id, "completed", f"Audio size: {len(audio_bytes)} bytes"
                )
            else:
                await self.publish_status(session_id, "error", "Synthesis failed")
        except Exception as e:
            logger.error(f"Error handling TTS request: {e}", exc_info=True)
            # suppress() also covers the NameError raised if the failure
            # occurred before session_id was assigned.
            with contextlib.suppress(Exception):
                await self.publish_status(session_id, "error", str(e))

    async def publish_status(self, session_id: str, status: str, message: str = ""):
        """Publish a TTS status update ("processing"/"completed"/"error") for a session."""
        status_msg = {
            "session_id": session_id,
            "status": status,
            "message": message,
            "timestamp": time.time(),
        }
        await self.nc.publish(f"{STATUS_SUBJECT_PREFIX}.{session_id}", msgpack.packb(status_msg))
        logger.debug(f"Published status '{status}' for session {session_id}")

    async def handle_list_voices(self, msg: Msg):
        """Handle request to list available voices (built-in + custom).

        Replies on msg.reply (when present) with the default speaker, the
        custom voices, and the registry's last refresh timestamp.
        """
        custom = self.voice_registry.list_voices()
        response = {
            "default_speaker": DEFAULT_SPEAKER,
            "custom_voices": custom,
            "last_refresh": self.voice_registry._last_refresh,
            "timestamp": time.time(),
        }
        if msg.reply:
            await msg.respond(msgpack.packb(response))
        logger.debug(f"Listed {len(custom)} custom voice(s)")

    async def handle_refresh_voices(self, msg: Msg):
        """Handle request to refresh the custom voice registry on demand."""
        count = self.voice_registry.refresh()
        response = {
            "count": count,
            "custom_voices": self.voice_registry.list_voices(),
            "timestamp": time.time(),
        }
        if msg.reply:
            await msg.respond(msgpack.packb(response))
        logger.info(f"Voice registry refreshed on demand: {count} voice(s)")

    async def _periodic_voice_refresh(self):
        """Periodically refresh the voice registry to pick up newly trained voices."""
        while self.running:
            await asyncio.sleep(VOICE_REGISTRY_REFRESH_SECONDS)
            # Re-check after the sleep so shutdown isn't delayed by a refresh.
            if not self.running:
                break
            try:
                self.voice_registry.refresh()
            except Exception as e:
                logger.error(f"Periodic voice registry refresh failed: {e}")

    async def run(self):
        """Main run loop: subscribe, serve until signalled, then clean up."""
        await self.setup()
        # Subscribe to TTS requests
        sub = await self.nc.subscribe(f"{REQUEST_SUBJECT_PREFIX}.>", cb=self.handle_request)
        logger.info(f"Subscribed to {REQUEST_SUBJECT_PREFIX}.>")
        # Subscribe to voice management subjects
        voices_sub = await self.nc.subscribe(VOICES_LIST_SUBJECT, cb=self.handle_list_voices)
        refresh_sub = await self.nc.subscribe(VOICES_REFRESH_SUBJECT, cb=self.handle_refresh_voices)
        logger.info(f"Subscribed to {VOICES_LIST_SUBJECT} and {VOICES_REFRESH_SUBJECT}")
        # Start periodic voice registry refresh
        refresh_task = asyncio.create_task(self._periodic_voice_refresh())

        # Handle shutdown
        def signal_handler():
            self.running = False

        # NOTE(review): asyncio.get_event_loop() inside a coroutine is
        # deprecated since Python 3.10 — asyncio.get_running_loop() is the
        # modern equivalent here.
        loop = asyncio.get_event_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)
        # Keep running
        while self.running:
            await asyncio.sleep(1)
        # Cleanup
        logger.info("Shutting down...")
        # NOTE(review): the cancelled task is never awaited, so its
        # CancelledError is discarded when the loop exits — acceptable, but
        # awaiting it under suppress(CancelledError) would be tidier.
        refresh_task.cancel()
        await sub.unsubscribe()
        await voices_sub.unsubscribe()
        await refresh_sub.unsubscribe()
        await self.nc.close()
        await self.http_client.aclose()
        logger.info("Shutdown complete")
if __name__ == "__main__":
    # Entry point: build the service and drive its run loop to completion.
    asyncio.run(StreamingTTS().run())