feat: custom voice support, CI pipeline, and Renovate config
- VoiceRegistry for trained voices from Kubeflow pipeline - Custom voice routing in synthesize() - NATS subjects for listing/refreshing voices - pyproject.toml with ruff/pytest config - Full test suite (26 tests) - Gitea Actions CI (lint, test, docker, notify) - Renovate config for automated dependency updates Ref: ADR-0056, ADR-0057
This commit is contained in:
404
tts_streaming.py
404
tts_streaming.py
@@ -10,13 +10,17 @@ Real-time Text-to-Speech service that processes synthesis requests from NATS:
|
||||
|
||||
This enables real-time voice synthesis for voice assistant applications.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import msgpack
|
||||
@@ -25,23 +29,26 @@ import nats.js
|
||||
from nats.aio.msg import Msg
|
||||
|
||||
# OpenTelemetry imports
|
||||
from opentelemetry import trace, metrics
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
|
||||
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
|
||||
OTLPMetricExporter as OTLPMetricExporterHTTP,
|
||||
)
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
|
||||
OTLPSpanExporter as OTLPSpanExporterHTTP,
|
||||
)
|
||||
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
||||
from opentelemetry.instrumentation.logging import LoggingInstrumentor
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, SERVICE_VERSION, Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger("tts-streaming")
|
||||
|
||||
@@ -52,53 +59,58 @@ def setup_telemetry():
|
||||
if not otel_enabled:
|
||||
logger.info("OpenTelemetry disabled")
|
||||
return None, None
|
||||
|
||||
otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317")
|
||||
|
||||
otel_endpoint = os.environ.get(
|
||||
"OTEL_EXPORTER_OTLP_ENDPOINT",
|
||||
"http://opentelemetry-collector.observability.svc.cluster.local:4317",
|
||||
)
|
||||
service_name = os.environ.get("OTEL_SERVICE_NAME", "tts-streaming")
|
||||
service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")
|
||||
|
||||
|
||||
hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
|
||||
hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
|
||||
use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key
|
||||
|
||||
resource = Resource.create({
|
||||
SERVICE_NAME: service_name,
|
||||
SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
|
||||
SERVICE_NAMESPACE: service_namespace,
|
||||
"deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
|
||||
"host.name": os.environ.get("HOSTNAME", "unknown"),
|
||||
})
|
||||
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
SERVICE_NAME: service_name,
|
||||
SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
|
||||
SERVICE_NAMESPACE: service_namespace,
|
||||
"deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
|
||||
"host.name": os.environ.get("HOSTNAME", "unknown"),
|
||||
}
|
||||
)
|
||||
|
||||
trace_provider = TracerProvider(resource=resource)
|
||||
|
||||
|
||||
if use_hyperdx:
|
||||
logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
|
||||
headers = {"authorization": hyperdx_api_key}
|
||||
otlp_span_exporter = OTLPSpanExporterHTTP(
|
||||
endpoint=f"{hyperdx_endpoint}/v1/traces",
|
||||
headers=headers
|
||||
endpoint=f"{hyperdx_endpoint}/v1/traces", headers=headers
|
||||
)
|
||||
otlp_metric_exporter = OTLPMetricExporterHTTP(
|
||||
endpoint=f"{hyperdx_endpoint}/v1/metrics",
|
||||
headers=headers
|
||||
endpoint=f"{hyperdx_endpoint}/v1/metrics", headers=headers
|
||||
)
|
||||
else:
|
||||
otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
|
||||
otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)
|
||||
|
||||
|
||||
trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
|
||||
trace.set_tracer_provider(trace_provider)
|
||||
|
||||
metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000)
|
||||
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
otlp_metric_exporter, export_interval_millis=60000
|
||||
)
|
||||
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
|
||||
metrics.set_meter_provider(meter_provider)
|
||||
|
||||
|
||||
HTTPXClientInstrumentor().instrument()
|
||||
LoggingInstrumentor().instrument(set_logging_format=True)
|
||||
|
||||
|
||||
destination = "HyperDX" if use_hyperdx else "OTEL Collector"
|
||||
logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")
|
||||
|
||||
|
||||
return trace.get_tracer(__name__), metrics.get_meter(__name__)
|
||||
|
||||
|
||||
@@ -110,6 +122,8 @@ NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"
|
||||
REQUEST_SUBJECT_PREFIX = "ai.voice.tts.request" # ai.voice.tts.request.{session_id}
|
||||
AUDIO_SUBJECT_PREFIX = "ai.voice.tts.audio" # ai.voice.tts.audio.{session_id}
|
||||
STATUS_SUBJECT_PREFIX = "ai.voice.tts.status" # ai.voice.tts.status.{session_id}
|
||||
VOICES_LIST_SUBJECT = "ai.voice.tts.voices.list" # List available voices
|
||||
VOICES_REFRESH_SUBJECT = "ai.voice.tts.voices.refresh" # Trigger registry refresh
|
||||
|
||||
# TTS parameters
|
||||
DEFAULT_SPEAKER = os.environ.get("TTS_DEFAULT_SPEAKER", "default")
|
||||
@@ -117,10 +131,114 @@ DEFAULT_LANGUAGE = os.environ.get("TTS_DEFAULT_LANGUAGE", "en")
|
||||
AUDIO_CHUNK_SIZE = int(os.environ.get("TTS_AUDIO_CHUNK_SIZE", "32768")) # 32KB chunks for streaming
|
||||
SAMPLE_RATE = int(os.environ.get("TTS_SAMPLE_RATE", "24000")) # XTTS default sample rate
|
||||
|
||||
# Custom voice model store (populated by coqui-voice-training Argo workflow)
|
||||
VOICE_MODEL_STORE = os.environ.get("VOICE_MODEL_STORE", "/models/tts/custom")
|
||||
VOICE_REGISTRY_REFRESH_SECONDS = int(os.environ.get("VOICE_REGISTRY_REFRESH_SECONDS", "300"))
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomVoice:
|
||||
"""A custom trained voice produced by the coqui-voice-training pipeline."""
|
||||
|
||||
name: str
|
||||
model_path: str
|
||||
config_path: str
|
||||
created_at: str
|
||||
language: str = "en"
|
||||
model_type: str = "coqui-tts"
|
||||
|
||||
|
||||
class VoiceRegistry:
|
||||
"""Registry of custom trained voices discovered from the model store.
|
||||
|
||||
Scans ``VOICE_MODEL_STORE`` for directories produced by the
|
||||
``coqui-voice-training`` Argo workflow. Each directory must contain
|
||||
``model_info.json`` and ``model.pth``.
|
||||
"""
|
||||
|
||||
def __init__(self, model_store_path: str) -> None:
|
||||
self.model_store = Path(model_store_path)
|
||||
self.voices: dict[str, CustomVoice] = {}
|
||||
self._last_refresh: float = 0.0
|
||||
|
||||
def refresh(self) -> int:
|
||||
"""Scan the model store for available custom voices.
|
||||
|
||||
Returns:
|
||||
Number of voices discovered.
|
||||
"""
|
||||
if not self.model_store.exists():
|
||||
logger.warning(f"Voice model store not found: {self.model_store}")
|
||||
return 0
|
||||
|
||||
discovered: dict[str, CustomVoice] = {}
|
||||
for voice_dir in self.model_store.iterdir():
|
||||
if not voice_dir.is_dir():
|
||||
continue
|
||||
|
||||
model_info_path = voice_dir / "model_info.json"
|
||||
if not model_info_path.exists():
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(model_info_path) as f:
|
||||
info = json.load(f)
|
||||
|
||||
model_path = voice_dir / "model.pth"
|
||||
config_path = voice_dir / "config.json"
|
||||
|
||||
if not model_path.exists():
|
||||
logger.warning(f"Model file missing for voice: {voice_dir.name}")
|
||||
continue
|
||||
|
||||
voice = CustomVoice(
|
||||
name=info.get("name", voice_dir.name),
|
||||
model_path=str(model_path),
|
||||
config_path=str(config_path) if config_path.exists() else "",
|
||||
created_at=info.get("created_at", ""),
|
||||
language=info.get("language", "en"),
|
||||
model_type=info.get("type", "coqui-tts"),
|
||||
)
|
||||
discovered[voice.name] = voice
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load voice info from {voice_dir}: {e}")
|
||||
continue
|
||||
|
||||
added = set(discovered) - set(self.voices)
|
||||
removed = set(self.voices) - set(discovered)
|
||||
|
||||
if added:
|
||||
logger.info(f"New voices discovered: {', '.join(sorted(added))}")
|
||||
if removed:
|
||||
logger.info(f"Voices removed: {', '.join(sorted(removed))}")
|
||||
|
||||
self.voices = discovered
|
||||
self._last_refresh = time.time()
|
||||
|
||||
logger.info(f"Voice registry refreshed: {len(self.voices)} custom voice(s) available")
|
||||
return len(self.voices)
|
||||
|
||||
def get(self, name: str) -> CustomVoice | None:
|
||||
"""Get a custom voice by name."""
|
||||
return self.voices.get(name)
|
||||
|
||||
def list_voices(self) -> list[dict]:
|
||||
"""List all available custom voices as serialisable dicts."""
|
||||
return [
|
||||
{
|
||||
"name": v.name,
|
||||
"language": v.language,
|
||||
"model_type": v.model_type,
|
||||
"created_at": v.created_at,
|
||||
}
|
||||
for v in self.voices.values()
|
||||
]
|
||||
|
||||
|
||||
class StreamingTTS:
|
||||
"""Streaming Text-to-Speech service using Coqui XTTS."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.nc = None
|
||||
self.js = None
|
||||
@@ -131,31 +249,32 @@ class StreamingTTS:
|
||||
self.meter = None
|
||||
self.synthesis_counter = None
|
||||
self.synthesis_duration = None
|
||||
self.active_sessions: Dict[str, dict] = {}
|
||||
|
||||
self.active_sessions: dict[str, dict] = {}
|
||||
self.voice_registry = VoiceRegistry(VOICE_MODEL_STORE)
|
||||
|
||||
async def setup(self):
|
||||
"""Initialize connections."""
|
||||
self.tracer, self.meter = setup_telemetry()
|
||||
|
||||
|
||||
if self.meter:
|
||||
self.synthesis_counter = self.meter.create_counter(
|
||||
name="tts_synthesis_total",
|
||||
description="Total number of TTS synthesis requests",
|
||||
unit="1"
|
||||
unit="1",
|
||||
)
|
||||
self.synthesis_duration = self.meter.create_histogram(
|
||||
name="tts_synthesis_duration_seconds",
|
||||
description="Duration of TTS synthesis",
|
||||
unit="s"
|
||||
unit="s",
|
||||
)
|
||||
|
||||
|
||||
# NATS connection
|
||||
self.nc = await nats.connect(NATS_URL)
|
||||
logger.info(f"Connected to NATS at {NATS_URL}")
|
||||
|
||||
|
||||
# Initialize JetStream context
|
||||
self.js = self.nc.jetstream()
|
||||
|
||||
|
||||
# Create or update stream for TTS messages
|
||||
try:
|
||||
stream_config = nats.js.api.StreamConfig(
|
||||
@@ -169,63 +288,88 @@ class StreamingTTS:
|
||||
logger.info("Created/updated JetStream stream: AI_VOICE_TTS")
|
||||
except Exception as e:
|
||||
logger.info(f"JetStream stream setup: {e}")
|
||||
|
||||
|
||||
# HTTP client for XTTS service
|
||||
self.http_client = httpx.AsyncClient(timeout=180.0)
|
||||
logger.info("HTTP client initialized")
|
||||
|
||||
|
||||
# Discover custom voices from model store
|
||||
self.voice_registry.refresh()
|
||||
|
||||
self.is_healthy = True
|
||||
|
||||
async def synthesize(self, text: str, speaker: str = None, language: str = None,
|
||||
speaker_wav_b64: str = None) -> Optional[bytes]:
|
||||
"""Synthesize speech using XTTS API."""
|
||||
|
||||
async def synthesize(
|
||||
self,
|
||||
text: str,
|
||||
speaker: str = None,
|
||||
language: str = None,
|
||||
speaker_wav_b64: str = None,
|
||||
) -> bytes | None:
|
||||
"""Synthesize speech using XTTS API.
|
||||
|
||||
When *speaker* matches a custom voice from the voice registry the
|
||||
request is enriched with the trained model path so the XTTS backend
|
||||
loads the fine-tuned model instead of the default one.
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# Check if speaker matches a custom trained voice
|
||||
custom_voice = self.voice_registry.get(speaker) if speaker else None
|
||||
|
||||
# Build request payload
|
||||
payload = {
|
||||
"text": text,
|
||||
"speaker": speaker or DEFAULT_SPEAKER,
|
||||
"language": language or DEFAULT_LANGUAGE,
|
||||
}
|
||||
|
||||
# Add speaker reference audio for voice cloning if provided
|
||||
if speaker_wav_b64:
|
||||
|
||||
if custom_voice:
|
||||
# Custom voice from coqui-voice-training pipeline
|
||||
payload["model_path"] = custom_voice.model_path
|
||||
if custom_voice.config_path:
|
||||
payload["config_path"] = custom_voice.config_path
|
||||
payload["language"] = language or custom_voice.language
|
||||
logger.info(
|
||||
f"Using custom voice '{custom_voice.name}' from {custom_voice.model_path}"
|
||||
)
|
||||
elif speaker_wav_b64:
|
||||
# Ad-hoc voice cloning via reference audio
|
||||
payload["speaker_wav"] = speaker_wav_b64
|
||||
|
||||
|
||||
# Call XTTS API
|
||||
response = await self.http_client.post(
|
||||
f"{XTTS_URL}/v1/audio/speech",
|
||||
json=payload
|
||||
)
|
||||
response = await self.http_client.post(f"{XTTS_URL}/v1/audio/speech", json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
audio_bytes = response.content
|
||||
|
||||
|
||||
duration = time.time() - start_time
|
||||
audio_duration = len(audio_bytes) / (SAMPLE_RATE * 2) # 16-bit audio
|
||||
rtf = duration / audio_duration if audio_duration > 0 else 0
|
||||
|
||||
logger.info(f"Synthesized {len(audio_bytes)} bytes in {duration:.2f}s (RTF: {rtf:.2f})")
|
||||
|
||||
|
||||
voice_label = custom_voice.name if custom_voice else (speaker or DEFAULT_SPEAKER)
|
||||
logger.info(
|
||||
f"Synthesized {len(audio_bytes)} bytes in {duration:.2f}s (RTF: {rtf:.2f}, voice: {voice_label})"
|
||||
)
|
||||
|
||||
if self.synthesis_duration:
|
||||
self.synthesis_duration.record(duration, {"speaker": speaker or DEFAULT_SPEAKER})
|
||||
|
||||
self.synthesis_duration.record(duration, {"speaker": voice_label})
|
||||
|
||||
return audio_bytes
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Synthesis failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def stream_audio(self, session_id: str, audio_bytes: bytes):
|
||||
"""Stream audio back to client in chunks."""
|
||||
total_chunks = (len(audio_bytes) + AUDIO_CHUNK_SIZE - 1) // AUDIO_CHUNK_SIZE
|
||||
|
||||
|
||||
for i in range(0, len(audio_bytes), AUDIO_CHUNK_SIZE):
|
||||
chunk = audio_bytes[i:i + AUDIO_CHUNK_SIZE]
|
||||
chunk = audio_bytes[i : i + AUDIO_CHUNK_SIZE]
|
||||
chunk_index = i // AUDIO_CHUNK_SIZE
|
||||
is_last = (i + AUDIO_CHUNK_SIZE) >= len(audio_bytes)
|
||||
|
||||
|
||||
message = {
|
||||
"session_id": session_id,
|
||||
"chunk_index": chunk_index,
|
||||
@@ -235,52 +379,51 @@ class StreamingTTS:
|
||||
"timestamp": time.time(),
|
||||
"sample_rate": SAMPLE_RATE,
|
||||
}
|
||||
|
||||
await self.nc.publish(
|
||||
f"{AUDIO_SUBJECT_PREFIX}.{session_id}",
|
||||
msgpack.packb(message)
|
||||
)
|
||||
|
||||
|
||||
await self.nc.publish(f"{AUDIO_SUBJECT_PREFIX}.{session_id}", msgpack.packb(message))
|
||||
|
||||
logger.debug(f"Sent chunk {chunk_index + 1}/{total_chunks} for session {session_id}")
|
||||
|
||||
|
||||
logger.info(f"Streamed {total_chunks} chunks for session {session_id}")
|
||||
|
||||
|
||||
async def handle_request(self, msg: Msg):
|
||||
"""Handle incoming TTS request."""
|
||||
try:
|
||||
# Extract session_id from subject: ai.voice.tts.request.{session_id}
|
||||
subject_parts = msg.subject.split('.')
|
||||
subject_parts = msg.subject.split(".")
|
||||
if len(subject_parts) < 5:
|
||||
logger.warning(f"Invalid subject format: {msg.subject}")
|
||||
return
|
||||
|
||||
|
||||
session_id = subject_parts[4]
|
||||
|
||||
|
||||
# Parse request using msgpack
|
||||
data = msgpack.unpackb(msg.data, raw=False)
|
||||
|
||||
|
||||
text = data.get("text", "")
|
||||
speaker = data.get("speaker")
|
||||
language = data.get("language")
|
||||
speaker_wav_b64 = data.get("speaker_wav_b64") # For voice cloning
|
||||
stream = data.get("stream", True) # Default to streaming
|
||||
|
||||
|
||||
if not text:
|
||||
logger.warning(f"Empty text for session {session_id}")
|
||||
await self.publish_status(session_id, "error", "Empty text provided")
|
||||
return
|
||||
|
||||
|
||||
logger.info(f"Processing TTS request for session {session_id}: {text[:50]}...")
|
||||
|
||||
|
||||
if self.synthesis_counter:
|
||||
self.synthesis_counter.add(1, {"session_id": session_id})
|
||||
|
||||
|
||||
# Publish status: processing
|
||||
await self.publish_status(session_id, "processing", f"Synthesizing {len(text)} characters")
|
||||
|
||||
await self.publish_status(
|
||||
session_id, "processing", f"Synthesizing {len(text)} characters"
|
||||
)
|
||||
|
||||
# Synthesize audio
|
||||
audio_bytes = await self.synthesize(text, speaker, language, speaker_wav_b64)
|
||||
|
||||
|
||||
if audio_bytes:
|
||||
if stream:
|
||||
# Stream audio in chunks
|
||||
@@ -294,21 +437,20 @@ class StreamingTTS:
|
||||
"sample_rate": SAMPLE_RATE,
|
||||
}
|
||||
await self.nc.publish(
|
||||
f"{AUDIO_SUBJECT_PREFIX}.{session_id}",
|
||||
msgpack.packb(message)
|
||||
f"{AUDIO_SUBJECT_PREFIX}.{session_id}", msgpack.packb(message)
|
||||
)
|
||||
|
||||
await self.publish_status(session_id, "completed", f"Audio size: {len(audio_bytes)} bytes")
|
||||
|
||||
await self.publish_status(
|
||||
session_id, "completed", f"Audio size: {len(audio_bytes)} bytes"
|
||||
)
|
||||
else:
|
||||
await self.publish_status(session_id, "error", "Synthesis failed")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling TTS request: {e}", exc_info=True)
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
await self.publish_status(session_id, "error", str(e))
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
async def publish_status(self, session_id: str, status: str, message: str = ""):
|
||||
"""Publish TTS status update."""
|
||||
status_msg = {
|
||||
@@ -317,35 +459,79 @@ class StreamingTTS:
|
||||
"message": message,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
await self.nc.publish(
|
||||
f"{STATUS_SUBJECT_PREFIX}.{session_id}",
|
||||
msgpack.packb(status_msg)
|
||||
)
|
||||
await self.nc.publish(f"{STATUS_SUBJECT_PREFIX}.{session_id}", msgpack.packb(status_msg))
|
||||
logger.debug(f"Published status '{status}' for session {session_id}")
|
||||
|
||||
|
||||
async def handle_list_voices(self, msg: Msg):
|
||||
"""Handle request to list available voices (built-in + custom)."""
|
||||
custom = self.voice_registry.list_voices()
|
||||
response = {
|
||||
"default_speaker": DEFAULT_SPEAKER,
|
||||
"custom_voices": custom,
|
||||
"last_refresh": self.voice_registry._last_refresh,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
if msg.reply:
|
||||
await msg.respond(msgpack.packb(response))
|
||||
logger.debug(f"Listed {len(custom)} custom voice(s)")
|
||||
|
||||
async def handle_refresh_voices(self, msg: Msg):
|
||||
"""Handle request to refresh the custom voice registry."""
|
||||
count = self.voice_registry.refresh()
|
||||
response = {
|
||||
"count": count,
|
||||
"custom_voices": self.voice_registry.list_voices(),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
if msg.reply:
|
||||
await msg.respond(msgpack.packb(response))
|
||||
logger.info(f"Voice registry refreshed on demand: {count} voice(s)")
|
||||
|
||||
async def _periodic_voice_refresh(self):
|
||||
"""Periodically refresh the voice registry to pick up newly trained voices."""
|
||||
while self.running:
|
||||
await asyncio.sleep(VOICE_REGISTRY_REFRESH_SECONDS)
|
||||
if not self.running:
|
||||
break
|
||||
try:
|
||||
self.voice_registry.refresh()
|
||||
except Exception as e:
|
||||
logger.error(f"Periodic voice registry refresh failed: {e}")
|
||||
|
||||
async def run(self):
|
||||
"""Main run loop."""
|
||||
await self.setup()
|
||||
|
||||
|
||||
# Subscribe to TTS requests
|
||||
sub = await self.nc.subscribe(f"{REQUEST_SUBJECT_PREFIX}.>", cb=self.handle_request)
|
||||
logger.info(f"Subscribed to {REQUEST_SUBJECT_PREFIX}.>")
|
||||
|
||||
|
||||
# Subscribe to voice management subjects
|
||||
voices_sub = await self.nc.subscribe(VOICES_LIST_SUBJECT, cb=self.handle_list_voices)
|
||||
refresh_sub = await self.nc.subscribe(VOICES_REFRESH_SUBJECT, cb=self.handle_refresh_voices)
|
||||
logger.info(f"Subscribed to {VOICES_LIST_SUBJECT} and {VOICES_REFRESH_SUBJECT}")
|
||||
|
||||
# Start periodic voice registry refresh
|
||||
refresh_task = asyncio.create_task(self._periodic_voice_refresh())
|
||||
|
||||
# Handle shutdown
|
||||
def signal_handler():
|
||||
self.running = False
|
||||
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
loop.add_signal_handler(sig, signal_handler)
|
||||
|
||||
|
||||
# Keep running
|
||||
while self.running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
# Cleanup
|
||||
logger.info("Shutting down...")
|
||||
refresh_task.cancel()
|
||||
await sub.unsubscribe()
|
||||
await voices_sub.unsubscribe()
|
||||
await refresh_sub.unsubscribe()
|
||||
await self.nc.close()
|
||||
await self.http_client.aclose()
|
||||
logger.info("Shutdown complete")
|
||||
|
||||
Reference in New Issue
Block a user