- tts_streaming.py: NATS-based TTS using XTTS HTTP API - Streaming audio chunks for low-latency playback - Voice cloning support via reference audio - Multi-language synthesis - OpenTelemetry instrumentation with HyperDX support
(357 lines, 14 KiB, Python)
#!/usr/bin/env python3
|
|
"""
|
|
Streaming TTS Service
|
|
|
|
Real-time Text-to-Speech service that processes synthesis requests from NATS:
|
|
1. Subscribe to TTS requests on "ai.voice.tts.request.{session_id}"
|
|
2. Synthesize speech using Coqui XTTS via HTTP API
|
|
3. Stream audio chunks back via "ai.voice.tts.audio.{session_id}"
|
|
4. Support for voice cloning and multi-speaker synthesis
|
|
|
|
This enables real-time voice synthesis for voice assistant applications.
|
|
"""
|
|
import asyncio
|
|
import base64
|
|
import logging
|
|
import os
|
|
import signal
|
|
import time
|
|
from typing import Dict, Optional
|
|
|
|
import httpx
|
|
import msgpack
|
|
import nats
|
|
import nats.js
|
|
from nats.aio.msg import Msg
|
|
|
|
# OpenTelemetry imports
|
|
from opentelemetry import trace, metrics
|
|
from opentelemetry.sdk.trace import TracerProvider
|
|
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
|
from opentelemetry.sdk.metrics import MeterProvider
|
|
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
|
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
|
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
|
|
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
|
|
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
|
|
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
|
|
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
|
from opentelemetry.instrumentation.logging import LoggingInstrumentor
|
|
|
|
# Configure logging
# Basic stdout logging; LoggingInstrumentor later rewrites the format to
# include trace/span IDs when OpenTelemetry is enabled.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# Module-level logger shared by all components of this service.
logger = logging.getLogger("tts-streaming")
|
|
|
|
|
|
def setup_telemetry():
    """Initialize OpenTelemetry tracing and metrics with HyperDX support.

    Returns a ``(tracer, meter)`` pair, or ``(None, None)`` when telemetry
    is disabled via the ``OTEL_ENABLED`` environment variable.
    """
    # Guard clause: telemetry is opt-out via OTEL_ENABLED.
    if os.environ.get("OTEL_ENABLED", "true").lower() != "true":
        logger.info("OpenTelemetry disabled")
        return None, None

    collector_endpoint = os.environ.get(
        "OTEL_EXPORTER_OTLP_ENDPOINT",
        "http://opentelemetry-collector.observability.svc.cluster.local:4317",
    )
    service_name = os.environ.get("OTEL_SERVICE_NAME", "tts-streaming")
    service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")

    api_key = os.environ.get("HYPERDX_API_KEY", "")
    hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
    # HyperDX is used only when explicitly enabled AND an API key is present.
    use_hyperdx = (
        os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and api_key
    )

    resource = Resource.create({
        SERVICE_NAME: service_name,
        SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
        SERVICE_NAMESPACE: service_namespace,
        "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
        "host.name": os.environ.get("HOSTNAME", "unknown"),
    })

    trace_provider = TracerProvider(resource=resource)

    if use_hyperdx:
        # HyperDX ingests OTLP over HTTP, authenticated via the
        # `authorization` header carrying the raw API key.
        logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
        auth_headers = {"authorization": api_key}
        span_exporter = OTLPSpanExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/traces",
            headers=auth_headers,
        )
        metric_exporter = OTLPMetricExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/metrics",
            headers=auth_headers,
        )
    else:
        # In-cluster collector speaks plaintext gRPC.
        span_exporter = OTLPSpanExporter(endpoint=collector_endpoint, insecure=True)
        metric_exporter = OTLPMetricExporter(endpoint=collector_endpoint, insecure=True)

    trace_provider.add_span_processor(BatchSpanProcessor(span_exporter))
    trace.set_tracer_provider(trace_provider)

    reader = PeriodicExportingMetricReader(metric_exporter, export_interval_millis=60000)
    metrics.set_meter_provider(MeterProvider(resource=resource, metric_readers=[reader]))

    # Auto-instrument outgoing HTTP calls and log records with trace context.
    HTTPXClientInstrumentor().instrument()
    LoggingInstrumentor().instrument(set_logging_format=True)

    destination = "HyperDX" if use_hyperdx else "OTEL Collector"
    logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")

    return trace.get_tracer(__name__), metrics.get_meter(__name__)
|
|
|
|
|
|
# Configuration from environment
# Base URL of the XTTS synthesis HTTP service.
XTTS_URL = os.environ.get("XTTS_URL", "http://xtts-predictor.ai-ml.svc.cluster.local")
# NATS server used for request/audio/status messaging.
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")

# NATS subjects (session_id is appended as the final token)
REQUEST_SUBJECT_PREFIX = "ai.voice.tts.request"  # ai.voice.tts.request.{session_id}
AUDIO_SUBJECT_PREFIX = "ai.voice.tts.audio"  # ai.voice.tts.audio.{session_id}
STATUS_SUBJECT_PREFIX = "ai.voice.tts.status"  # ai.voice.tts.status.{session_id}

# TTS parameters
DEFAULT_SPEAKER = os.environ.get("TTS_DEFAULT_SPEAKER", "default")
DEFAULT_LANGUAGE = os.environ.get("TTS_DEFAULT_LANGUAGE", "en")
AUDIO_CHUNK_SIZE = int(os.environ.get("TTS_AUDIO_CHUNK_SIZE", "32768"))  # 32KB chunks for streaming
SAMPLE_RATE = int(os.environ.get("TTS_SAMPLE_RATE", "24000"))  # XTTS default sample rate
|
|
|
|
|
|
class StreamingTTS:
    """Streaming Text-to-Speech service using Coqui XTTS.

    Consumes synthesis requests from NATS subjects
    ``ai.voice.tts.request.{session_id}``, calls the XTTS HTTP API, and
    publishes the resulting audio back on
    ``ai.voice.tts.audio.{session_id}`` either as base64-encoded chunks
    (streaming) or as a single message. Status updates go to
    ``ai.voice.tts.status.{session_id}``.
    """

    def __init__(self):
        # All connections are created lazily in setup().
        self.nc = None                 # NATS client connection
        self.js = None                 # JetStream context
        self.http_client = None        # httpx.AsyncClient for the XTTS API
        self.running = True            # cleared by the signal handler to stop run()
        self.is_healthy = False        # set True once setup() completes
        self.tracer = None             # OTel tracer (None when telemetry disabled)
        self.meter = None              # OTel meter (None when telemetry disabled)
        self.synthesis_counter = None
        self.synthesis_duration = None
        self.active_sessions: Dict[str, dict] = {}

    async def setup(self):
        """Initialize telemetry, NATS/JetStream, and the XTTS HTTP client."""
        self.tracer, self.meter = setup_telemetry()

        if self.meter:
            # NOTE(review): tagging the counter with session_id (see
            # handle_request) yields unbounded metric cardinality — worth
            # revisiting, kept for compatibility with existing dashboards.
            self.synthesis_counter = self.meter.create_counter(
                name="tts_synthesis_total",
                description="Total number of TTS synthesis requests",
                unit="1"
            )
            self.synthesis_duration = self.meter.create_histogram(
                name="tts_synthesis_duration_seconds",
                description="Duration of TTS synthesis",
                unit="s"
            )

        # NATS connection
        self.nc = await nats.connect(NATS_URL)
        logger.info(f"Connected to NATS at {NATS_URL}")

        # Initialize JetStream context
        self.js = self.nc.jetstream()

        # Create or update the stream for TTS messages. Failure here is
        # non-fatal (e.g. stream already exists with a different config),
        # so it is logged and ignored.
        try:
            stream_config = nats.js.api.StreamConfig(
                name="AI_VOICE_TTS",
                subjects=["ai.voice.tts.>"],
                retention=nats.js.api.RetentionPolicy.LIMITS,
                max_age=300,  # 5 minutes
                storage=nats.js.api.StorageType.MEMORY,
            )
            await self.js.add_stream(stream_config)
            logger.info("Created/updated JetStream stream: AI_VOICE_TTS")
        except Exception as e:
            logger.info(f"JetStream stream setup: {e}")

        # HTTP client for the XTTS service; generous timeout because
        # synthesis of long texts can take minutes.
        self.http_client = httpx.AsyncClient(timeout=180.0)
        logger.info("HTTP client initialized")

        self.is_healthy = True

    async def synthesize(self, text: str, speaker: Optional[str] = None,
                         language: Optional[str] = None,
                         speaker_wav_b64: Optional[str] = None) -> Optional[bytes]:
        """Synthesize speech using the XTTS HTTP API.

        Args:
            text: Text to synthesize.
            speaker: Speaker preset name; falls back to DEFAULT_SPEAKER.
            language: Language code; falls back to DEFAULT_LANGUAGE.
            speaker_wav_b64: Optional base64 reference audio for voice cloning.

        Returns:
            Raw audio bytes from the XTTS service, or None on any failure
            (callers report the error to the session via publish_status).
        """
        start_time = time.time()

        try:
            # Build request payload
            payload = {
                "text": text,
                "speaker": speaker or DEFAULT_SPEAKER,
                "language": language or DEFAULT_LANGUAGE,
            }

            # Add speaker reference audio for voice cloning if provided
            if speaker_wav_b64:
                payload["speaker_wav"] = speaker_wav_b64

            # Call XTTS API
            response = await self.http_client.post(
                f"{XTTS_URL}/v1/audio/speech",
                json=payload
            )
            response.raise_for_status()

            audio_bytes = response.content

            # Real-time factor: wall-clock time / audio duration. Assumes
            # raw 16-bit mono PCM at SAMPLE_RATE — TODO confirm the XTTS
            # endpoint returns headerless PCM, otherwise RTF is approximate.
            duration = time.time() - start_time
            audio_duration = len(audio_bytes) / (SAMPLE_RATE * 2)  # 16-bit audio
            rtf = duration / audio_duration if audio_duration > 0 else 0

            logger.info(f"Synthesized {len(audio_bytes)} bytes in {duration:.2f}s (RTF: {rtf:.2f})")

            if self.synthesis_duration:
                self.synthesis_duration.record(duration, {"speaker": speaker or DEFAULT_SPEAKER})

            return audio_bytes

        except Exception as e:
            # Deliberate best-effort: return None so handle_request can
            # publish an error status instead of crashing the subscriber.
            logger.error(f"Synthesis failed: {e}")
            return None

    async def stream_audio(self, session_id: str, audio_bytes: bytes):
        """Stream audio back to the client in AUDIO_CHUNK_SIZE chunks.

        Each chunk is a msgpack message carrying base64 audio plus
        chunk_index/total_chunks/is_last bookkeeping so the client can
        reassemble and detect completion.
        """
        total_chunks = (len(audio_bytes) + AUDIO_CHUNK_SIZE - 1) // AUDIO_CHUNK_SIZE

        for i in range(0, len(audio_bytes), AUDIO_CHUNK_SIZE):
            chunk = audio_bytes[i:i + AUDIO_CHUNK_SIZE]
            chunk_index = i // AUDIO_CHUNK_SIZE
            is_last = (i + AUDIO_CHUNK_SIZE) >= len(audio_bytes)

            message = {
                "session_id": session_id,
                "chunk_index": chunk_index,
                "total_chunks": total_chunks,
                "audio_b64": base64.b64encode(chunk).decode(),
                "is_last": is_last,
                "timestamp": time.time(),
                "sample_rate": SAMPLE_RATE,
            }

            await self.nc.publish(
                f"{AUDIO_SUBJECT_PREFIX}.{session_id}",
                msgpack.packb(message)
            )

            logger.debug(f"Sent chunk {chunk_index + 1}/{total_chunks} for session {session_id}")

        logger.info(f"Streamed {total_chunks} chunks for session {session_id}")

    async def handle_request(self, msg: Msg):
        """Handle an incoming TTS request message.

        Parses the msgpack payload, synthesizes audio, and streams it back
        (or sends it whole when the request sets "stream": false). All
        failures are reported to the session's status subject.
        """
        # Pre-bind so the error path below can always report a session,
        # even if the exception fires before the subject is parsed.
        session_id = "unknown"
        try:
            # Extract session_id from subject: ai.voice.tts.request.{session_id}
            subject_parts = msg.subject.split('.')
            if len(subject_parts) < 5:
                logger.warning(f"Invalid subject format: {msg.subject}")
                return

            session_id = subject_parts[4]

            # Parse request using msgpack
            data = msgpack.unpackb(msg.data, raw=False)

            text = data.get("text", "")
            speaker = data.get("speaker")
            language = data.get("language")
            speaker_wav_b64 = data.get("speaker_wav_b64")  # For voice cloning
            stream = data.get("stream", True)  # Default to streaming

            if not text:
                logger.warning(f"Empty text for session {session_id}")
                await self.publish_status(session_id, "error", "Empty text provided")
                return

            logger.info(f"Processing TTS request for session {session_id}: {text[:50]}...")

            if self.synthesis_counter:
                self.synthesis_counter.add(1, {"session_id": session_id})

            # Publish status: processing
            await self.publish_status(session_id, "processing", f"Synthesizing {len(text)} characters")

            # Synthesize audio
            audio_bytes = await self.synthesize(text, speaker, language, speaker_wav_b64)

            if audio_bytes:
                if stream:
                    # Stream audio in chunks
                    await self.stream_audio(session_id, audio_bytes)
                else:
                    # Send complete audio in one message
                    message = {
                        "session_id": session_id,
                        "audio_b64": base64.b64encode(audio_bytes).decode(),
                        "timestamp": time.time(),
                        "sample_rate": SAMPLE_RATE,
                    }
                    await self.nc.publish(
                        f"{AUDIO_SUBJECT_PREFIX}.{session_id}",
                        msgpack.packb(message)
                    )

                await self.publish_status(session_id, "completed", f"Audio size: {len(audio_bytes)} bytes")
            else:
                await self.publish_status(session_id, "error", "Synthesis failed")

        except Exception as e:
            logger.error(f"Error handling TTS request: {e}", exc_info=True)
            try:
                await self.publish_status(session_id, "error", str(e))
            except Exception:
                # Status publishing is best-effort; never let it mask the
                # original error. (Was a bare `except:`, which would also
                # swallow CancelledError during shutdown.)
                pass

    async def publish_status(self, session_id: str, status: str, message: str = ""):
        """Publish a TTS status update ("processing"/"completed"/"error")."""
        status_msg = {
            "session_id": session_id,
            "status": status,
            "message": message,
            "timestamp": time.time(),
        }
        await self.nc.publish(
            f"{STATUS_SUBJECT_PREFIX}.{session_id}",
            msgpack.packb(status_msg)
        )
        logger.debug(f"Published status '{status}' for session {session_id}")

    async def run(self):
        """Main run loop: subscribe, serve until signaled, then clean up."""
        await self.setup()

        # Subscribe to TTS requests
        sub = await self.nc.subscribe(f"{REQUEST_SUBJECT_PREFIX}.>", cb=self.handle_request)
        logger.info(f"Subscribed to {REQUEST_SUBJECT_PREFIX}.>")

        # Handle shutdown
        def signal_handler():
            self.running = False

        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated here since Python 3.10.
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        # Keep running
        while self.running:
            await asyncio.sleep(1)

        # Cleanup
        logger.info("Shutting down...")
        await sub.unsubscribe()
        await self.nc.close()
        await self.http_client.aclose()
        logger.info("Shutdown complete")
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: construct the service and drive its run loop until a
    # SIGTERM/SIGINT clears the running flag.
    service = StreamingTTS()
    asyncio.run(service.run())
|