feat: Add voice assistant handler and Kubeflow pipeline
- voice_assistant.py: Standalone NATS handler with full RAG pipeline
- voice_assistant_v2.py: Handler-base implementation
- pipelines/voice_pipeline.py: KFP SDK pipeline definitions
- Dockerfiles for both standalone and handler-base versions

Pipeline: STT → Embeddings → Milvus → Rerank → LLM → TTS
42
.gitignore
vendored
Normal file
@@ -0,0 +1,42 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual environments
venv/
.venv/
ENV/

# IDE
.idea/
.vscode/
*.swp
*.swo
*~

# Compiled KFP pipelines
*.yaml
!pipelines/*.py

# Local
.env
.env.local
*.log
29
Dockerfile
Normal file
@@ -0,0 +1,29 @@
FROM python:3.13-slim

WORKDIR /app

# Install uv for fast, reliable package management
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better caching
COPY requirements.txt .
RUN uv pip install --system --no-cache -r requirements.txt

# Copy application code
COPY voice_assistant.py .

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "print('healthy')" || exit 1

# Run the application
CMD ["python", "voice_assistant.py"]
9
Dockerfile.v2
Normal file
@@ -0,0 +1,9 @@
# Voice Assistant v2 - Using handler-base with audio support
ARG BASE_TAG=local-audio
FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG}

WORKDIR /app

COPY voice_assistant_v2.py ./voice_assistant.py

CMD ["python", "voice_assistant.py"]
121
README.md
@@ -1,3 +1,120 @@
-# voice-assistant
+# Voice Assistant

-voice assistance please
+End-to-end voice assistant pipeline for the DaviesTechLabs AI/ML platform.

## Components

### Real-time Handler (NATS-based)

The voice assistant service listens on NATS for audio requests and returns synthesized speech responses.

**Pipeline:** STT → Embeddings → Milvus RAG → Rerank → LLM → TTS

| File | Description |
|------|-------------|
| `voice_assistant.py` | Standalone handler (v1) |
| `voice_assistant_v2.py` | Handler using handler-base library |
| `Dockerfile` | Standalone image |
| `Dockerfile.v2` | Handler-base image |

### Kubeflow Pipeline (Batch)

For batch processing or async workflows via Kubeflow Pipelines.

| Pipeline | Description |
|----------|-------------|
| `voice_pipeline.yaml` | Full STT → RAG → TTS pipeline |
| `rag_pipeline.yaml` | Text-only RAG pipeline |
| `tts_pipeline.yaml` | Simple text-to-speech |

```bash
# Compile pipelines
cd pipelines
pip install kfp==2.12.1
python voice_pipeline.py
```

## Architecture

```
NATS (voice.request)
        │
        ▼
┌───────────────────┐
│  Voice Assistant  │
│      Handler      │
└───────────────────┘
        │
        ├──▶ Whisper STT (elminster)
        │         │
        │         ▼
        ├──▶ BGE Embeddings (drizzt)
        │         │
        │         ▼
        ├──▶ Milvus Vector Search
        │         │
        │         ▼
        ├──▶ BGE Reranker (danilo)
        │         │
        │         ▼
        ├──▶ vLLM (khelben)
        │         │
        │         ▼
        └──▶ XTTS TTS (elminster)
                  │
                  ▼
NATS (voice.response.{id})
```

## Configuration

| Environment Variable | Default | Description |
|---------------------|---------|-------------|
| `NATS_URL` | `nats://nats.ai-ml.svc.cluster.local:4222` | NATS server |
| `WHISPER_URL` | `http://whisper-predictor.ai-ml.svc.cluster.local` | STT service |
| `EMBEDDINGS_URL` | `http://embeddings-predictor.ai-ml.svc.cluster.local` | Embeddings |
| `RERANKER_URL` | `http://reranker-predictor.ai-ml.svc.cluster.local` | Reranker |
| `VLLM_URL` | `http://llm-draft.ai-ml.svc.cluster.local:8000` | LLM service |
| `TTS_URL` | `http://tts-predictor.ai-ml.svc.cluster.local` | TTS service |
| `MILVUS_HOST` | `milvus.ai-ml.svc.cluster.local` | Vector DB |
| `COLLECTION_NAME` | `knowledge_base` | Milvus collection |
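The handler resolves each setting with a plain environment lookup that falls back to the documented in-cluster default. A minimal sketch of that pattern, using the variable names from the table above (shown for three of the settings):

```python
import os

# Each setting falls back to the in-cluster default from the table above
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
WHISPER_URL = os.environ.get("WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local")
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "knowledge_base")
```

Overriding any of them for local runs is then just, e.g., `NATS_URL=nats://localhost:4222 python voice_assistant.py`.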

## NATS Message Format

### Request (voice.request)

```json
{
  "request_id": "uuid",
  "audio": "base64-encoded-audio",
  "language": "en",
  "collection": "knowledge_base"
}
```

### Response (voice.response.{request_id})

```json
{
  "request_id": "uuid",
  "transcription": "user question",
  "response": "assistant answer",
  "audio": "base64-encoded-audio"
}
```
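A client can exercise this contract with `nats-py`. The sketch below builds a request payload matching the schema above; the `build_voice_request` helper is illustrative (not part of this repo), and the publish/subscribe steps are shown as comments since they need a running NATS server:

```python
import base64
import json
import uuid

def build_voice_request(audio_bytes: bytes, language: str = "en",
                        collection: str = "knowledge_base") -> dict:
    """Build a voice.request payload matching the README schema."""
    return {
        "request_id": str(uuid.uuid4()),
        "audio": base64.b64encode(audio_bytes).decode("utf-8"),
        "language": language,
        "collection": collection,
    }

# Round trip over NATS (requires a running server and nats-py):
#   nc = await nats.connect("nats://localhost:4222")
#   req = build_voice_request(open("question.wav", "rb").read())
#   sub = await nc.subscribe(f"voice.response.{req['request_id']}")
#   await nc.publish("voice.request", json.dumps(req).encode())
#   msg = await sub.next_msg(timeout=30)
#   reply = json.loads(msg.data)  # transcription, response, audio
```

Subscribing to `voice.response.{request_id}` before publishing avoids a race where the handler replies before the client is listening.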

## Building

```bash
# Standalone image (v1)
docker build -f Dockerfile -t voice-assistant:latest .

# Handler-base image (v2 - recommended)
docker build -f Dockerfile.v2 -t voice-assistant:v2 .
```

## Related

- [homelab-design](https://git.daviestechlabs.io/daviestechlabs/homelab-design) - Architecture docs
- [kuberay-images](https://git.daviestechlabs.io/daviestechlabs/kuberay-images) - Ray worker images
- [handler-base](https://github.com/Billy-Davies-2/llm-workflows/tree/main/handler-base) - Base handler library
316
pipelines/voice_pipeline.py
Normal file
@@ -0,0 +1,316 @@
#!/usr/bin/env python3
"""
Voice Pipeline - Kubeflow Pipelines SDK

Compile this to create a Kubeflow Pipeline for voice assistant workflows.

Usage:
    pip install kfp==2.12.1
    python voice_pipeline.py
    # Upload voice_pipeline.yaml to Kubeflow Pipelines UI
"""

from kfp import dsl
from kfp import compiler


@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["httpx"]
)
def transcribe_audio(
    audio_b64: str,
    whisper_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local"
) -> str:
    """Transcribe audio using Whisper STT service."""
    import base64
    import httpx

    audio_bytes = base64.b64decode(audio_b64)

    with httpx.Client(timeout=120.0) as client:
        response = client.post(
            f"{whisper_url}/v1/audio/transcriptions",
            files={"file": ("audio.wav", audio_bytes, "audio/wav")},
            data={"model": "whisper-large-v3", "language": "en"}
        )
        result = response.json()

    return result.get("text", "")


@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["httpx"]
)
def generate_embeddings(
    text: str,
    embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local"
) -> list:
    """Generate embeddings for RAG retrieval."""
    import httpx

    with httpx.Client(timeout=60.0) as client:
        response = client.post(
            f"{embeddings_url}/embeddings",
            json={"input": text, "model": "bge-small-en-v1.5"}
        )
        result = response.json()

    return result["data"][0]["embedding"]


@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["pymilvus"]
)
def retrieve_context(
    embedding: list,
    milvus_host: str = "milvus.ai-ml.svc.cluster.local",
    collection_name: str = "knowledge_base",
    top_k: int = 5
) -> list:
    """Retrieve relevant documents from Milvus vector database."""
    from pymilvus import connections, Collection, utility

    connections.connect(host=milvus_host, port=19530)

    if not utility.has_collection(collection_name):
        return []

    collection = Collection(collection_name)
    collection.load()

    results = collection.search(
        data=[embedding],
        anns_field="embedding",
        param={"metric_type": "COSINE", "params": {"nprobe": 10}},
        limit=top_k,
        output_fields=["text", "source"]
    )

    documents = []
    for hits in results:
        for hit in hits:
            documents.append({
                "text": hit.entity.get("text"),
                "source": hit.entity.get("source"),
                "score": hit.distance
            })

    return documents


@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["httpx"]
)
def rerank_documents(
    query: str,
    documents: list,
    reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local",
    top_k: int = 3
) -> list:
    """Rerank documents using BGE reranker."""
    import httpx

    if not documents:
        return []

    with httpx.Client(timeout=60.0) as client:
        response = client.post(
            f"{reranker_url}/v1/rerank",
            json={
                "query": query,
                "documents": [doc["text"] for doc in documents],
                "model": "bge-reranker-v2-m3"
            }
        )
        result = response.json()

    # Sort by rerank score
    reranked = sorted(
        zip(documents, result.get("scores", [0] * len(documents))),
        key=lambda x: x[1],
        reverse=True
    )[:top_k]

    return [doc for doc, score in reranked]


@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["httpx"]
)
def generate_response(
    query: str,
    context: list,
    vllm_url: str = "http://llm-draft.ai-ml.svc.cluster.local:8000",
    model: str = "mistralai/Mistral-7B-Instruct-v0.3"
) -> str:
    """Generate response using vLLM."""
    import httpx

    # Build context
    if context:
        context_text = "\n\n".join([doc["text"] for doc in context])
        user_content = f"Context:\n{context_text}\n\nQuestion: {query}"
    else:
        user_content = query

    system_prompt = """You are a helpful voice assistant.
Answer questions based on the provided context when available.
Keep responses concise and natural for speech synthesis."""

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content}
    ]

    with httpx.Client(timeout=180.0) as client:
        response = client.post(
            f"{vllm_url}/v1/chat/completions",
            json={
                "model": model,
                "messages": messages,
                "max_tokens": 512,
                "temperature": 0.7
            }
        )
        result = response.json()

    return result["choices"][0]["message"]["content"]


@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["httpx"]
)
def synthesize_speech(
    text: str,
    tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local"
) -> str:
    """Convert text to speech using TTS service."""
    import base64
    import httpx

    with httpx.Client(timeout=120.0) as client:
        response = client.post(
            f"{tts_url}/v1/audio/speech",
            json={
                "input": text,
                "voice": "en_US-lessac-high",
                "response_format": "wav"
            }
        )
        audio_b64 = base64.b64encode(response.content).decode("utf-8")

    return audio_b64


@dsl.pipeline(
    name="voice-assistant-rag-pipeline",
    description="End-to-end voice assistant with RAG: STT -> Embeddings -> Milvus -> Rerank -> LLM -> TTS"
)
def voice_assistant_pipeline(
    audio_b64: str,
    collection_name: str = "knowledge_base"
):
    """
    Voice Assistant Pipeline with RAG

    Args:
        audio_b64: Base64-encoded audio file
        collection_name: Milvus collection for RAG
    """

    # Step 1: Transcribe audio with Whisper
    transcribe_task = transcribe_audio(audio_b64=audio_b64)
    transcribe_task.set_caching_options(enable_caching=False)

    # Step 2: Generate embeddings
    embed_task = generate_embeddings(text=transcribe_task.output)
    embed_task.set_caching_options(enable_caching=True)

    # Step 3: Retrieve context from Milvus
    retrieve_task = retrieve_context(
        embedding=embed_task.output,
        collection_name=collection_name
    )

    # Step 4: Rerank documents
    rerank_task = rerank_documents(
        query=transcribe_task.output,
        documents=retrieve_task.output
    )

    # Step 5: Generate response with context
    llm_task = generate_response(
        query=transcribe_task.output,
        context=rerank_task.output
    )

    # Step 6: Synthesize speech
    tts_task = synthesize_speech(text=llm_task.output)


@dsl.pipeline(
    name="text-to-speech-pipeline",
    description="Simple text to speech pipeline"
)
def text_to_speech_pipeline(text: str):
    """Simple TTS pipeline for testing."""
    tts_task = synthesize_speech(text=text)


@dsl.pipeline(
    name="rag-query-pipeline",
    description="RAG query pipeline: Embed -> Retrieve -> Rerank -> LLM"
)
def rag_query_pipeline(
    query: str,
    collection_name: str = "knowledge_base"
):
    """
    RAG Query Pipeline (text input, no voice)

    Args:
        query: Text query
        collection_name: Milvus collection name
    """
    # Embed the query
    embed_task = generate_embeddings(text=query)

    # Retrieve from Milvus
    retrieve_task = retrieve_context(
        embedding=embed_task.output,
        collection_name=collection_name
    )

    # Rerank
    rerank_task = rerank_documents(
        query=query,
        documents=retrieve_task.output
    )

    # Generate response
    llm_task = generate_response(
        query=query,
        context=rerank_task.output
    )


if __name__ == "__main__":
    # Compile all pipelines
    pipelines = [
        ("voice_pipeline.yaml", voice_assistant_pipeline),
        ("tts_pipeline.yaml", text_to_speech_pipeline),
        ("rag_pipeline.yaml", rag_query_pipeline),
    ]
    for filename, pipeline_func in pipelines:
        compiler.Compiler().compile(pipeline_func, filename)
        print(f"Compiled: {filename}")

    print("\nUpload these YAML files to Kubeflow Pipelines UI at:")
    print("  http://kubeflow.example.com/pipelines")
15
requirements.txt
Normal file
@@ -0,0 +1,15 @@
nats-py
httpx
pymilvus
numpy
msgpack
redis>=5.0.0
opentelemetry-api
opentelemetry-sdk
opentelemetry-exporter-otlp-proto-grpc
opentelemetry-exporter-otlp-proto-http
opentelemetry-instrumentation-httpx
opentelemetry-instrumentation-logging
# MLflow for inference metrics tracking
mlflow>=2.10.0
psycopg2-binary>=2.9.0
875
voice_assistant.py
Normal file
@@ -0,0 +1,875 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Voice Assistant Service
|
||||
|
||||
End-to-end voice assistant pipeline:
|
||||
1. Listen for audio on NATS subject "voice.request"
|
||||
2. Transcribe with Whisper (STT)
|
||||
3. Generate embeddings for RAG
|
||||
4. Retrieve context from Milvus
|
||||
5. Rerank with BGE reranker
|
||||
6. Generate response with vLLM
|
||||
7. Synthesize speech with XTTS
|
||||
8. Publish result to NATS "voice.response"
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
# Install dependencies on startup
|
||||
subprocess.check_call([
|
||||
sys.executable, "-m", "pip", "install", "-q",
|
||||
"-r", "/app/requirements.txt"
|
||||
])
|
||||
|
||||
import httpx
|
||||
import msgpack
|
||||
import nats
|
||||
import redis.asyncio as redis
|
||||
from pymilvus import connections, Collection, utility
|
||||
|
||||
# OpenTelemetry imports
|
||||
from opentelemetry import trace, metrics
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
|
||||
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
|
||||
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
||||
from opentelemetry.instrumentation.logging import LoggingInstrumentor
|
||||
|
||||
# MLflow inference tracking
|
||||
try:
|
||||
from mlflow_utils import InferenceMetricsTracker
|
||||
from mlflow_utils.inference_tracker import InferenceMetrics
|
||||
MLFLOW_AVAILABLE = True
|
||||
except ImportError:
|
||||
MLFLOW_AVAILABLE = False
|
||||
InferenceMetricsTracker = None
|
||||
InferenceMetrics = None
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger("voice-assistant")
|
||||
|
||||
|
||||
def setup_telemetry():
|
||||
"""Initialize OpenTelemetry tracing and metrics."""
|
||||
otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true"
|
||||
if not otel_enabled:
|
||||
logger.info("OpenTelemetry disabled")
|
||||
return None, None
|
||||
|
||||
otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317")
|
||||
service_name = os.environ.get("OTEL_SERVICE_NAME", "voice-assistant")
|
||||
service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")
|
||||
|
||||
# HyperDX configuration
|
||||
hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
|
||||
hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
|
||||
use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key
|
||||
|
||||
resource = Resource.create({
|
||||
SERVICE_NAME: service_name,
|
||||
SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
|
||||
SERVICE_NAMESPACE: service_namespace,
|
||||
"deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
|
||||
"host.name": os.environ.get("HOSTNAME", "unknown"),
|
||||
})
|
||||
|
||||
trace_provider = TracerProvider(resource=resource)
|
||||
|
||||
if use_hyperdx:
|
||||
logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
|
||||
headers = {"authorization": hyperdx_api_key}
|
||||
otlp_span_exporter = OTLPSpanExporterHTTP(
|
||||
endpoint=f"{hyperdx_endpoint}/v1/traces",
|
||||
headers=headers
|
||||
)
|
||||
otlp_metric_exporter = OTLPMetricExporterHTTP(
|
||||
endpoint=f"{hyperdx_endpoint}/v1/metrics",
|
||||
headers=headers
|
||||
)
|
||||
else:
|
||||
otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
|
||||
otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)
|
||||
|
||||
trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
|
||||
trace.set_tracer_provider(trace_provider)
|
||||
|
||||
metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000)
|
||||
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
|
||||
metrics.set_meter_provider(meter_provider)
|
||||
|
||||
HTTPXClientInstrumentor().instrument()
|
||||
LoggingInstrumentor().instrument(set_logging_format=True)
|
||||
|
||||
destination = "HyperDX" if use_hyperdx else "OTEL Collector"
|
||||
logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")
|
||||
|
||||
return trace.get_tracer(__name__), metrics.get_meter(__name__)
|
||||
|
||||
# Configuration from environment
|
||||
WHISPER_URL = os.environ.get(
|
||||
"WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"
|
||||
)
|
||||
TTS_URL = os.environ.get("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local:5002")
|
||||
EMBEDDINGS_URL = os.environ.get(
|
||||
"EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"
|
||||
)
|
||||
RERANKER_URL = os.environ.get(
|
||||
"RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"
|
||||
)
|
||||
VLLM_URL = os.environ.get("VLLM_URL", "http://llm-draft.ai-ml.svc.cluster.local:8000")
|
||||
LLM_MODEL = os.environ.get("LLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.3")
|
||||
MILVUS_HOST = os.environ.get("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local")
|
||||
MILVUS_PORT = int(os.environ.get("MILVUS_PORT", "19530"))
|
||||
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "knowledge_base")
|
||||
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
|
||||
VALKEY_URL = os.environ.get("VALKEY_URL", "redis://valkey.ai-ml.svc.cluster.local:6379")
|
||||
|
||||
# MLflow configuration
|
||||
MLFLOW_ENABLED = os.environ.get("MLFLOW_ENABLED", "true").lower() == "true"
|
||||
MLFLOW_TRACKING_URI = os.environ.get(
|
||||
"MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
|
||||
)
|
||||
|
||||
# Context window limits (characters)
|
||||
MAX_CONTEXT_LENGTH = int(os.environ.get("MAX_CONTEXT_LENGTH", "8000")) # Prevent unbounded growth
|
||||
|
||||
# NATS subjects (ai.* schema)
|
||||
# Per-user channels matching companions-frontend pattern
|
||||
REQUEST_SUBJECT = "ai.voice.user.*.request" # Wildcard subscription for all users
|
||||
PREMIUM_REQUEST_SUBJECT = "ai.voice.premium.user.*.request" # Premium users
|
||||
RESPONSE_SUBJECT = "ai.voice.response" # Response published to specific request_id
|
||||
STREAM_RESPONSE_SUBJECT = "ai.voice.response.stream" # Streaming responses (token chunks)
|
||||
|
||||
# System prompt for the assistant
|
||||
SYSTEM_PROMPT = """You are a helpful voice assistant.
|
||||
Answer questions based on the provided context when available.
|
||||
Keep responses concise and natural for speech synthesis.
|
||||
If you don't know the answer, say so clearly."""
|
||||
|
||||
|
||||
class VoiceAssistant:
|
||||
def __init__(self):
|
||||
self.nc = None
|
||||
self.http_client = None
|
||||
self.collection = None
|
||||
self.valkey_client = None
|
||||
self.running = True
|
||||
self.tracer = None
|
||||
self.meter = None
|
||||
self.request_counter = None
|
||||
self.request_duration = None
|
||||
self.stt_duration = None
|
||||
self.tts_duration = None
|
||||
# MLflow inference tracker
|
||||
self.mlflow_tracker = None
|
||||
|
||||
async def setup(self):
|
||||
"""Initialize all connections."""
|
||||
# Initialize OpenTelemetry
|
||||
self.tracer, self.meter = setup_telemetry()
|
||||
|
||||
# Setup metrics
|
||||
if self.meter:
|
||||
self.request_counter = self.meter.create_counter(
|
||||
"voice.requests",
|
||||
description="Number of voice requests processed",
|
||||
unit="1"
|
||||
)
|
||||
self.request_duration = self.meter.create_histogram(
|
||||
"voice.request_duration",
|
||||
description="Duration of voice request processing",
|
||||
unit="s"
|
||||
)
|
||||
self.stt_duration = self.meter.create_histogram(
|
||||
"voice.stt_duration",
|
||||
description="Duration of speech-to-text processing",
|
||||
unit="s"
|
||||
)
|
||||
self.tts_duration = self.meter.create_histogram(
|
||||
"voice.tts_duration",
|
||||
description="Duration of text-to-speech processing",
|
||||
unit="s"
|
||||
)
|
||||
|
||||
# Initialize MLflow inference tracker
|
||||
if MLFLOW_ENABLED and MLFLOW_AVAILABLE:
|
||||
try:
|
||||
self.mlflow_tracker = InferenceMetricsTracker(
|
||||
service_name="voice-assistant",
|
||||
experiment_name="voice-inference",
|
||||
tracking_uri=MLFLOW_TRACKING_URI,
|
||||
batch_size=50,
|
||||
flush_interval_seconds=60.0,
|
||||
)
|
||||
await self.mlflow_tracker.start()
|
||||
logger.info(f"MLflow inference tracking enabled at {MLFLOW_TRACKING_URI}")
|
||||
except Exception as e:
|
||||
logger.warning(f"MLflow initialization failed: {e}, tracking disabled")
|
||||
self.mlflow_tracker = None
|
||||
elif not MLFLOW_AVAILABLE:
|
||||
logger.info("MLflow utils not available, inference tracking disabled")
|
||||
else:
|
||||
logger.info("MLflow tracking disabled via MLFLOW_ENABLED=false")
|
||||
|
||||
# NATS connection
|
||||
self.nc = await nats.connect(NATS_URL)
|
||||
logger.info(f"Connected to NATS at {NATS_URL}")
|
||||
|
||||
# HTTP client for services
|
||||
self.http_client = httpx.AsyncClient(timeout=180.0)
|
||||
|
||||
# Connect to Valkey for conversation history and context caching
|
||||
try:
|
||||
self.valkey_client = redis.from_url(
|
||||
VALKEY_URL,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=5
|
||||
)
|
||||
await self.valkey_client.ping()
|
||||
logger.info(f"Connected to Valkey at {VALKEY_URL}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Valkey connection failed: {e}, conversation history disabled")
|
||||
self.valkey_client = None
|
||||
|
||||
# Connect to Milvus if collection exists
|
||||
try:
|
||||
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
|
||||
if utility.has_collection(COLLECTION_NAME):
|
||||
self.collection = Collection(COLLECTION_NAME)
|
||||
self.collection.load()
|
||||
logger.info(f"Connected to Milvus collection: {COLLECTION_NAME}")
|
||||
else:
|
||||
logger.warning(f"Collection {COLLECTION_NAME} not found, RAG disabled")
|
||||
except Exception as e:
|
||||
logger.warning(f"Milvus connection failed: {e}, RAG disabled")
|
||||
|
||||
async def transcribe(self, audio_b64: str) -> str:
|
||||
"""Transcribe audio using Whisper."""
|
||||
try:
|
||||
audio_bytes = base64.b64decode(audio_b64)
|
||||
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
|
||||
response = await self.http_client.post(
|
||||
f"{WHISPER_URL}/v1/audio/transcriptions", files=files
|
||||
)
|
||||
result = response.json()
|
||||
transcript = result.get("text", "")
|
||||
logger.info(f"Transcribed: {transcript[:100]}...")
|
||||
return transcript
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription failed: {e}")
|
||||
return ""
|
||||
|
||||
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Get embeddings from the embedding service."""
|
||||
try:
|
||||
response = await self.http_client.post(
|
||||
f"{EMBEDDINGS_URL}/embeddings", json={"input": texts, "model": "bge"}
|
||||
)
|
||||
result = response.json()
|
||||
return [d["embedding"] for d in result.get("data", [])]
|
||||
except Exception as e:
|
||||
logger.error(f"Embedding failed: {e}")
|
||||
return []
|
||||
|
||||
async def search_milvus(
|
||||
self, query_embedding: List[float], top_k: int = 5
|
||||
) -> List[Dict]:
|
||||
"""Search Milvus for relevant documents."""
|
||||
if not self.collection:
|
||||
return []
|
||||
try:
|
||||
results = self.collection.search(
|
||||
data=[query_embedding],
|
||||
anns_field="embedding",
|
||||
param={"metric_type": "COSINE", "params": {"ef": 64}},
|
||||
limit=top_k,
|
||||
output_fields=["text", "book_name", "page_num"],
|
||||
)
|
||||
docs = []
|
||||
for hits in results:
|
||||
for hit in hits:
|
||||
docs.append(
|
||||
{
|
||||
"text": hit.entity.get("text", ""),
|
||||
"source": f'{hit.entity.get("book_name", "")} p.{hit.entity.get("page_num", "")}',
|
||||
"score": hit.score,
|
||||
}
|
||||
)
|
||||
return docs
|
||||
except Exception as e:
|
||||
logger.error(f"Milvus search failed: {e}")
|
||||
return []
|
||||
|
||||
async def rerank(self, query: str, documents: List[str]) -> List[Dict]:
|
||||
"""Rerank documents using the reranker service."""
|
||||
if not documents:
|
||||
return []
|
||||
try:
|
||||
response = await self.http_client.post(
|
||||
f"{RERANKER_URL}/v1/rerank",
|
||||
json={"query": query, "documents": documents},
|
||||
)
|
||||
return response.json().get("results", [])
|
||||
except Exception as e:
|
||||
logger.error(f"Reranking failed: {e}")
|
||||
return [{"index": i, "relevance_score": 0.5} for i in range(len(documents))]
|
||||
|
||||
async def get_conversation_history(self, session_id: str, max_messages: int = 10) -> List[Dict]:
|
||||
"""Retrieve conversation history from Valkey."""
|
||||
if not self.valkey_client or not session_id:
|
||||
return []
|
||||
try:
|
||||
key = f"voice:history:{session_id}"
|
||||
# Get the most recent messages (stored as a list)
|
||||
history_json = await self.valkey_client.lrange(key, -max_messages, -1)
|
||||
history = [json.loads(msg) for msg in history_json]
|
||||
logger.info(f"Retrieved {len(history)} messages from history for session {session_id}")
|
||||
return history
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get conversation history: {e}")
|
||||
return []
|
||||
|
||||
async def save_message_to_history(self, session_id: str, role: str, content: str, ttl: int = 3600):
|
||||
"""Save a message to conversation history in Valkey."""
|
||||
if not self.valkey_client or not session_id:
|
||||
return
|
||||
try:
|
||||
key = f"voice:history:{session_id}"
|
||||
message = json.dumps({"role": role, "content": content, "timestamp": time.time()})
|
||||
# Use RPUSH to append to the list
|
||||
await self.valkey_client.rpush(key, message)
|
||||
# Set TTL on the key (1 hour by default)
|
||||
await self.valkey_client.expire(key, ttl)
|
||||
logger.debug(f"Saved {role} message to history for session {session_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save message to history: {e}")
|
||||
|
||||
    async def get_context_window(self, session_id: str) -> Optional[str]:
        """Retrieve cached context window from Valkey for attention offloading."""
        if not self.valkey_client or not session_id:
            return None
        try:
            key = f"voice:context:{session_id}"
            context = await self.valkey_client.get(key)
            if context:
                logger.info(f"Retrieved cached context window for session {session_id}")
            return context
        except Exception as e:
            logger.error(f"Failed to get context window: {e}")
            return None

    async def save_context_window(self, session_id: str, context: str, ttl: int = 3600):
        """Save context window to Valkey for attention offloading."""
        if not self.valkey_client or not session_id:
            return
        try:
            key = f"voice:context:{session_id}"
            await self.valkey_client.set(key, context, ex=ttl)
            logger.debug(f"Saved context window for session {session_id}")
        except Exception as e:
            logger.error(f"Failed to save context window: {e}")

    async def generate_response(self, query: str, context: str = "", session_id: Optional[str] = None) -> str:
        """Generate response using vLLM with conversation history from Valkey."""
        try:
            messages = [{"role": "system", "content": SYSTEM_PROMPT}]

            # Add conversation history from Valkey if session exists
            if session_id:
                history = await self.get_conversation_history(session_id)
                messages.extend(history)

            if context:
                messages.append(
                    {
                        "role": "user",
                        "content": f"Context:\n{context}\n\nQuestion: {query}",
                    }
                )
            else:
                messages.append({"role": "user", "content": query})

            response = await self.http_client.post(
                f"{VLLM_URL}/v1/chat/completions",
                json={
                    "model": LLM_MODEL,
                    "messages": messages,
                    "max_tokens": 500,
                    "temperature": 0.7,
                },
            )
            # Surface HTTP errors (httpx does not raise on 4xx/5xx by default)
            response.raise_for_status()
            result = response.json()
            answer = result["choices"][0]["message"]["content"]
            logger.info(f"Generated response: {answer[:100]}...")

            # Save messages to conversation history
            if session_id:
                await self.save_message_to_history(session_id, "user", query)
                await self.save_message_to_history(session_id, "assistant", answer)

            return answer
        except Exception as e:
            logger.error(f"LLM generation failed: {e}")
            return "I'm sorry, I couldn't generate a response."

    async def generate_response_streaming(self, query: str, context: str = "", request_id: str = "", session_id: Optional[str] = None):
        """Generate streaming response using vLLM and publish chunks to NATS.

        Streams tokens from vLLM as they are generated, publishing each chunk
        to the NATS streaming subject. Returns the complete response text.
        """
        try:
            messages = [{"role": "system", "content": SYSTEM_PROMPT}]

            # Add conversation history from Valkey if session exists
            if session_id:
                history = await self.get_conversation_history(session_id)
                messages.extend(history)

            if context:
                messages.append(
                    {
                        "role": "user",
                        "content": f"Context:\n{context}\n\nQuestion: {query}",
                    }
                )
            else:
                messages.append({"role": "user", "content": query})

            full_response = ""

            # Stream response from vLLM
            async with self.http_client.stream(
                "POST",
                f"{VLLM_URL}/v1/chat/completions",
                json={
                    "model": LLM_MODEL,
                    "messages": messages,
                    "max_tokens": 500,
                    "temperature": 0.7,
                    "stream": True,
                },
                timeout=60.0,
            ) as response:
                # Parse SSE (Server-Sent Events) stream
                async for line in response.aiter_lines():
                    if not line or not line.startswith("data: "):
                        continue

                    data_str = line[6:]  # Remove "data: " prefix
                    if data_str.strip() == "[DONE]":
                        break

                    try:
                        chunk_data = json.loads(data_str)

                        # Extract token from delta
                        if chunk_data.get("choices") and len(chunk_data["choices"]) > 0:
                            delta = chunk_data["choices"][0].get("delta", {})
                            content = delta.get("content", "")

                            if content:
                                full_response += content

                                # Publish token chunk to NATS streaming subject
                                chunk_msg = {
                                    "request_id": request_id,
                                    "type": "chunk",
                                    "content": content,
                                    "done": False,
                                }
                                await self.nc.publish(
                                    f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
                                    msgpack.packb(chunk_msg),
                                )
                    except json.JSONDecodeError:
                        continue

            # Send completion message
            completion_msg = {
                "request_id": request_id,
                "type": "done",
                "content": "",
                "done": True,
            }
            await self.nc.publish(
                f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
                msgpack.packb(completion_msg),
            )

            logger.info(f"Streamed complete response ({len(full_response)} chars) for request {request_id}")

            # Save messages to conversation history
            if session_id:
                await self.save_message_to_history(session_id, "user", query)
                await self.save_message_to_history(session_id, "assistant", full_response)

            return full_response

        except Exception as e:
            logger.error(f"Streaming LLM generation failed: {e}")
            # Send error message
            error_msg = {
                "request_id": request_id,
                "type": "error",
                "content": "I'm sorry, I couldn't generate a response.",
                "done": True,
                "error": str(e),
            }
            await self.nc.publish(
                f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
                msgpack.packb(error_msg),
            )
            return "I'm sorry, I couldn't generate a response."

    async def synthesize_speech(self, text: str, language: str = "en") -> str:
        """Convert text to speech using XTTS (Coqui TTS)."""
        try:
            # XTTS API endpoint - uses /api/tts for synthesis
            # The Coqui TTS server API accepts text and returns wav audio
            response = await self.http_client.get(
                f"{TTS_URL}/api/tts",
                params={
                    "text": text,
                    "language_id": language,
                    # Optional: specify speaker_id for multi-speaker models
                },
            )
            if response.status_code == 200:
                audio_b64 = base64.b64encode(response.content).decode("utf-8")
                logger.info(f"Synthesized {len(response.content)} bytes of audio")
                return audio_b64
            else:
                logger.error(
                    f"TTS returned status {response.status_code}: {response.text}"
                )
                return ""
        except Exception as e:
            logger.error(f"TTS failed: {e}")
            return ""

    async def process_request(self, msg, is_premium=False):
        """Process a voice assistant request."""
        start_time = time.time()
        span = None
        # Initialize so the except handler can safely call data.get() even if
        # msgpack decoding fails
        data = {}

        # MLflow metrics tracking
        mlflow_metrics = None
        stt_start = None
        embedding_start = None
        rag_start = None
        rerank_start = None
        llm_start = None
        tts_start = None

        try:
            data = msgpack.unpackb(msg.data, raw=False)
            request_id = data.get("request_id", "unknown")
            audio_b64 = data.get("audio_b64", "")
            user_id = data.get("user_id")

            # Initialize MLflow metrics if available
            if self.mlflow_tracker and MLFLOW_AVAILABLE:
                mlflow_metrics = InferenceMetrics(
                    request_id=request_id,
                    user_id=user_id,
                    session_id=data.get("session_id"),
                    model_name=LLM_MODEL,
                    model_endpoint=VLLM_URL,
                )

            # Start tracing span
            if self.tracer:
                span = self.tracer.start_span("voice.process_request")
                span.set_attribute("request_id", request_id)
                span.set_attribute("user_id", user_id or "anonymous")
                span.set_attribute("premium", is_premium)

            # Support both new parameters and backward compatibility with use_rag
            use_rag = data.get("use_rag")  # Legacy parameter
            enable_rag = data.get(
                "enable_rag", use_rag if use_rag is not None else True
            )
            enable_reranker = data.get(
                "enable_reranker", use_rag if use_rag is not None else True
            )
            enable_streaming = data.get("enable_streaming", False)  # New parameter for streaming

            # Premium channel retrieves more documents for deeper RAG
            default_top_k = 15 if is_premium else 5
            top_k = data.get("top_k", default_top_k)

            language = data.get("language", "en")
            session_id = data.get("session_id")

            # Update MLflow metrics with request params
            if mlflow_metrics:
                mlflow_metrics.rag_enabled = enable_rag
                mlflow_metrics.reranker_enabled = enable_reranker
                mlflow_metrics.is_streaming = enable_streaming
                mlflow_metrics.is_premium = is_premium

            # Add attributes to span
            if span:
                span.set_attribute("enable_rag", enable_rag)
                span.set_attribute("enable_reranker", enable_reranker)
                span.set_attribute("enable_streaming", enable_streaming)
                span.set_attribute("top_k", top_k)

            logger.info(
                f"Processing {'premium ' if is_premium else ''}voice request {request_id} (RAG: {enable_rag}, Reranker: {enable_reranker}, top_k: {top_k})"
            )

            # Warn if reranker is enabled without RAG
            if enable_reranker and not enable_rag:
                logger.warning(
                    f"Request {request_id}: Reranker enabled without RAG - no documents to rerank"
                )

            # Step 1: Transcribe audio
            stt_start = time.time()
            transcript = await self.transcribe(audio_b64)
            if mlflow_metrics:
                mlflow_metrics.stt_latency = time.time() - stt_start
                mlflow_metrics.prompt_length = len(transcript) if transcript else 0

            if not transcript:
                if mlflow_metrics:
                    mlflow_metrics.has_error = True
                    mlflow_metrics.error_message = "Transcription failed"
                await self.publish_error(request_id, "Transcription failed")
                return

            context = ""
            rag_sources = []
            docs = []

            # Step 2: RAG retrieval (if enabled)
            if enable_rag and self.collection:
                # Get embeddings
                embedding_start = time.time()
                embeddings = await self.get_embeddings([transcript])
                if mlflow_metrics:
                    mlflow_metrics.embedding_latency = time.time() - embedding_start

                if embeddings:
                    # Search Milvus with configurable top_k
                    rag_start = time.time()
                    docs = await self.search_milvus(embeddings[0], top_k=top_k)
                    if mlflow_metrics:
                        mlflow_metrics.rag_search_latency = time.time() - rag_start
                        mlflow_metrics.rag_documents_retrieved = len(docs)
                    if docs:
                        rag_sources = [d.get("source", "") for d in docs]

            # Step 3: Reranking (if enabled and we have documents)
            if enable_reranker and docs:
                # Rerank documents
                rerank_start = time.time()
                doc_texts = [d["text"] for d in docs]
                reranked = await self.rerank(transcript, doc_texts)
                if mlflow_metrics:
                    mlflow_metrics.rerank_latency = time.time() - rerank_start
                # Take top 3 reranked documents with bounds checking
                sorted_docs = sorted(
                    reranked, key=lambda x: x.get("relevance_score", 0), reverse=True
                )[:3]
                # Build context with bounds checking
                # Note: doc_texts and docs have the same length (doc_texts derived from docs)
                context_parts = []
                sources = []
                for item in sorted_docs:
                    idx = item.get("index", -1)
                    if 0 <= idx < len(docs):
                        context_parts.append(doc_texts[idx])
                        sources.append(docs[idx].get("source", ""))
                    else:
                        logger.warning(
                            f"Reranker returned invalid index {idx} for {len(docs)} docs"
                        )
                context = "\n\n".join(context_parts)
                rag_sources = sources
            elif docs:
                # Use documents without reranking (take top 3)
                doc_texts = [d["text"] for d in docs[:3]]
                context = "\n\n".join(doc_texts)
                rag_sources = [d.get("source", "") for d in docs[:3]]

            # Step 4: Generate response (streaming or non-streaming)
            # Check for cached context window from Valkey (for attention offloading)
            cached_context = None
            if session_id:
                cached_context = await self.get_context_window(session_id)

            # Combine RAG context with cached context if available
            if cached_context and context:
                # Prepend cached context to current RAG context
                combined_context = f"{cached_context}\n\n{context}"
                # Truncate to prevent unbounded growth
                if len(combined_context) > MAX_CONTEXT_LENGTH:
                    logger.warning(f"Context length {len(combined_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating")
                    # Keep the most recent context (from the end)
                    combined_context = combined_context[-MAX_CONTEXT_LENGTH:]
                context = combined_context
            elif cached_context:
                # Only cached context, still need to check length
                if len(cached_context) > MAX_CONTEXT_LENGTH:
                    logger.warning(f"Cached context length {len(cached_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating")
                    cached_context = cached_context[-MAX_CONTEXT_LENGTH:]
                context = cached_context

            # Save the combined context for future use (already truncated if needed)
            if session_id and context:
                await self.save_context_window(session_id, context)

            # Track number of RAG docs used after reranking
            if mlflow_metrics and enable_rag:
                mlflow_metrics.rag_documents_used = min(3, len(docs)) if docs else 0

            llm_start = time.time()
            if enable_streaming:
                # Use streaming response
                answer = await self.generate_response_streaming(transcript, context, request_id, session_id)
            else:
                # Use non-streaming response
                answer = await self.generate_response(transcript, context, session_id)

            if mlflow_metrics:
                mlflow_metrics.llm_latency = time.time() - llm_start
                mlflow_metrics.response_length = len(answer)
                # Estimate token counts (rough approximation: 4 chars per token)
                mlflow_metrics.input_tokens = len(transcript) // 4
                mlflow_metrics.output_tokens = len(answer) // 4
                mlflow_metrics.total_tokens = mlflow_metrics.input_tokens + mlflow_metrics.output_tokens

            # Step 5: Synthesize speech
            tts_start = time.time()
            audio_response = await self.synthesize_speech(answer, language)
            if mlflow_metrics:
                mlflow_metrics.tts_latency = time.time() - tts_start

            # Publish result
            result = {
                "request_id": request_id,
                "user_id": user_id,
                "transcript": transcript,
                "response_text": answer,
                "audio_b64": audio_response,
                "used_rag": bool(context),
                "rag_enabled": enable_rag,
                "reranker_enabled": enable_reranker,
                "rag_sources": rag_sources,
                "success": True,
            }
            await self.nc.publish(
                f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result)
            )
            logger.info(f"Published response for request {request_id}")

            # Record metrics
            duration = time.time() - start_time
            if self.request_counter:
                self.request_counter.add(1, {"premium": str(is_premium), "rag_enabled": str(enable_rag), "success": "true"})
            if self.request_duration:
                self.request_duration.record(duration, {"premium": str(is_premium), "rag_enabled": str(enable_rag)})
            if span:
                span.set_attribute("success", True)
                span.set_attribute("response_length", len(answer))
                span.set_attribute("transcript_length", len(transcript))

            # Log to MLflow
            if self.mlflow_tracker and mlflow_metrics:
                mlflow_metrics.total_latency = duration
                await self.mlflow_tracker.log_inference(mlflow_metrics)

        except Exception as e:
            logger.error(f"Request processing failed: {e}")
            if self.request_counter:
                self.request_counter.add(1, {"premium": str(is_premium), "success": "false"})
            if span:
                span.set_attribute("success", False)
                span.set_attribute("error", str(e))

            # Log error to MLflow
            if self.mlflow_tracker and mlflow_metrics:
                mlflow_metrics.has_error = True
                mlflow_metrics.error_message = str(e)
                mlflow_metrics.total_latency = time.time() - start_time
                await self.mlflow_tracker.log_inference(mlflow_metrics)

            await self.publish_error(data.get("request_id", "unknown"), str(e))
        finally:
            if span:
                span.end()

    async def publish_error(self, request_id: str, error: str):
        """Publish an error response."""
        result = {"request_id": request_id, "error": error, "success": False}
        await self.nc.publish(
            f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result)
        )

    async def process_premium_request(self, msg):
        """Process a premium voice request (wrapper for deeper RAG)."""
        await self.process_request(msg, is_premium=True)

    async def run(self):
        """Main run loop."""
        await self.setup()

        # Subscribe to standard voice requests
        sub = await self.nc.subscribe(REQUEST_SUBJECT, cb=self.process_request)
        logger.info(f"Subscribed to {REQUEST_SUBJECT}")

        # Subscribe to premium voice requests (deeper RAG retrieval)
        premium_sub = await self.nc.subscribe(
            PREMIUM_REQUEST_SUBJECT, cb=self.process_premium_request
        )
        logger.info(f"Subscribed to {PREMIUM_REQUEST_SUBJECT}")

        # Handle shutdown
        def signal_handler():
            self.running = False

        # get_running_loop() is the non-deprecated way to get the loop from
        # inside a coroutine
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        # Keep running
        while self.running:
            await asyncio.sleep(1)

        # Cleanup
        await sub.unsubscribe()
        await premium_sub.unsubscribe()
        await self.nc.close()
        if self.valkey_client:
            await self.valkey_client.close()
        if self.collection:
            connections.disconnect("default")
        if self.mlflow_tracker:
            await self.mlflow_tracker.stop()
        logger.info("Shutdown complete")


if __name__ == "__main__":
    assistant = VoiceAssistant()
    asyncio.run(assistant.run())
234 voice_assistant_v2.py Normal file
@@ -0,0 +1,234 @@
#!/usr/bin/env python3
"""
Voice Assistant Service (Refactored)

End-to-end voice assistant pipeline using handler-base:
1. Listen for audio on NATS subject "voice.request"
2. Transcribe with Whisper (STT)
3. Generate embeddings for RAG
4. Retrieve context from Milvus
5. Rerank with BGE reranker
6. Generate response with vLLM
7. Synthesize speech with XTTS
8. Publish result to NATS "voice.response.{request_id}"
"""
import base64
import logging
from typing import Any, Optional

from nats.aio.msg import Msg

from handler_base import Handler, Settings
from handler_base.clients import (
    EmbeddingsClient,
    RerankerClient,
    LLMClient,
    TTSClient,
    STTClient,
    MilvusClient,
)
from handler_base.telemetry import create_span

logger = logging.getLogger("voice-assistant")


class VoiceSettings(Settings):
    """Voice assistant specific settings."""

    service_name: str = "voice-assistant"

    # RAG settings
    rag_top_k: int = 10
    rag_rerank_top_k: int = 5
    rag_collection: str = "documents"

    # Audio settings
    stt_language: Optional[str] = None  # Auto-detect
    tts_language: str = "en"

    # Response settings
    include_transcription: bool = True
    include_sources: bool = False


class VoiceAssistant(Handler):
    """
    Voice request handler with full STT -> RAG -> LLM -> TTS pipeline.

    Request format (msgpack):
        {
            "request_id": "uuid",
            "audio": "base64 encoded audio",
            "language": "optional language code",
            "collection": "optional collection name"
        }

    Response format:
        {
            "request_id": "uuid",
            "transcription": "what the user said",
            "response": "generated text response",
            "audio": "base64 encoded response audio"
        }
    """

    def __init__(self):
        self.voice_settings = VoiceSettings()
        super().__init__(
            subject="voice.request",
            settings=self.voice_settings,
            queue_group="voice-assistants",
        )

    async def setup(self) -> None:
        """Initialize service clients."""
        logger.info("Initializing voice assistant clients...")

        self.stt = STTClient(self.voice_settings)
        self.embeddings = EmbeddingsClient(self.voice_settings)
        self.reranker = RerankerClient(self.voice_settings)
        self.llm = LLMClient(self.voice_settings)
        self.tts = TTSClient(self.voice_settings)
        self.milvus = MilvusClient(self.voice_settings)

        await self.milvus.connect(self.voice_settings.rag_collection)

        logger.info("Voice assistant clients initialized")

    async def teardown(self) -> None:
        """Clean up service clients."""
        logger.info("Closing voice assistant clients...")

        await self.stt.close()
        await self.embeddings.close()
        await self.reranker.close()
        await self.llm.close()
        await self.tts.close()
        await self.milvus.close()

        logger.info("Voice assistant clients closed")

    async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
        """Handle incoming voice request."""
        request_id = data.get("request_id", "unknown")
        audio_b64 = data.get("audio", "")
        language = data.get("language", self.voice_settings.stt_language)
        collection = data.get("collection", self.voice_settings.rag_collection)

        logger.info(f"Processing voice request {request_id}")

        with create_span("voice.process") as span:
            if span:
                span.set_attribute("request.id", request_id)

            # 1. Decode audio
            audio_bytes = base64.b64decode(audio_b64)

            # 2. Transcribe audio to text
            transcription = await self._transcribe(audio_bytes, language)
            query = transcription.get("text", "")

            if not query.strip():
                logger.warning(f"Empty transcription for request {request_id}")
                return {
                    "request_id": request_id,
                    "error": "Could not transcribe audio",
                }

            logger.info(f"Transcribed: {query[:50]}...")

            # 3. Generate query embedding
            embedding = await self._get_embedding(query)

            # 4. Search Milvus for context
            documents = await self._search_context(embedding, collection)

            # 5. Rerank documents
            reranked = await self._rerank_documents(query, documents)

            # 6. Build context
            context = self._build_context(reranked)

            # 7. Generate LLM response
            response_text = await self._generate_response(query, context)

            # 8. Synthesize speech
            response_audio = await self._synthesize_speech(response_text)

            # Build response
            result = {
                "request_id": request_id,
                "response": response_text,
                "audio": response_audio,
            }

            if self.voice_settings.include_transcription:
                result["transcription"] = query

            if self.voice_settings.include_sources:
                result["sources"] = [
                    {"text": d["document"][:200], "score": d["score"]}
                    for d in reranked[:3]
                ]

            logger.info(f"Completed voice request {request_id}")

            # Publish to response subject
            response_subject = f"voice.response.{request_id}"
            await self.nats.publish(response_subject, result)

            return result

    async def _transcribe(
        self, audio: bytes, language: Optional[str]
    ) -> dict:
        """Transcribe audio to text."""
        with create_span("voice.stt"):
            return await self.stt.transcribe(audio, language=language)

    async def _get_embedding(self, text: str) -> list[float]:
        """Generate embedding for query text."""
        with create_span("voice.embedding"):
            return await self.embeddings.embed_single(text)

    async def _search_context(
        self, embedding: list[float], collection: str
    ) -> list[dict]:
        """Search Milvus for relevant documents."""
        with create_span("voice.search"):
            return await self.milvus.search_with_texts(
                embedding,
                limit=self.voice_settings.rag_top_k,
                text_field="text",
            )

    async def _rerank_documents(
        self, query: str, documents: list[dict]
    ) -> list[dict]:
        """Rerank documents by relevance."""
        with create_span("voice.rerank"):
            texts = [d.get("text", "") for d in documents]
            return await self.reranker.rerank(
                query, texts, top_k=self.voice_settings.rag_rerank_top_k
            )

    def _build_context(self, documents: list[dict]) -> str:
        """Build context string from ranked documents."""
        return "\n\n".join(d.get("document", "") for d in documents)

    async def _generate_response(self, query: str, context: str) -> str:
        """Generate LLM response."""
        with create_span("voice.generate"):
            return await self.llm.generate(query, context=context)

    async def _synthesize_speech(self, text: str) -> str:
        """Synthesize speech and return base64."""
        with create_span("voice.tts"):
            audio_bytes = await self.tts.synthesize(
                text, language=self.voice_settings.tts_language
            )
            return base64.b64encode(audio_bytes).decode()


if __name__ == "__main__":
    VoiceAssistant().run()