fix: ruff formatting, allow-direct-references, and noqa for Kubeflow pipeline vars
All checks were successful
CI / Lint (push) Successful in 51s
CI / Test (push) Successful in 55s
CI / Release (push) Successful in 7s
CI / Notify (push) Successful in 2s

This commit is contained in:
2026-02-02 08:44:14 -05:00
parent 58465b77d8
commit 0462412353
5 changed files with 208 additions and 237 deletions

View File

@@ -14,70 +14,58 @@ from kfp import dsl
from kfp import compiler from kfp import compiler
@dsl.component( @dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"])
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def transcribe_audio( def transcribe_audio(
audio_b64: str, audio_b64: str, whisper_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local"
whisper_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local"
) -> str: ) -> str:
"""Transcribe audio using Whisper STT service.""" """Transcribe audio using Whisper STT service."""
import base64 import base64
import httpx import httpx
audio_bytes = base64.b64decode(audio_b64) audio_bytes = base64.b64decode(audio_b64)
with httpx.Client(timeout=120.0) as client: with httpx.Client(timeout=120.0) as client:
response = client.post( response = client.post(
f"{whisper_url}/v1/audio/transcriptions", f"{whisper_url}/v1/audio/transcriptions",
files={"file": ("audio.wav", audio_bytes, "audio/wav")}, files={"file": ("audio.wav", audio_bytes, "audio/wav")},
data={"model": "whisper-large-v3", "language": "en"} data={"model": "whisper-large-v3", "language": "en"},
) )
result = response.json() result = response.json()
return result.get("text", "") return result.get("text", "")
@dsl.component( @dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"])
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def generate_embeddings( def generate_embeddings(
text: str, text: str, embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local"
embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local"
) -> list: ) -> list:
"""Generate embeddings for RAG retrieval.""" """Generate embeddings for RAG retrieval."""
import httpx import httpx
with httpx.Client(timeout=60.0) as client: with httpx.Client(timeout=60.0) as client:
response = client.post( response = client.post(
f"{embeddings_url}/embeddings", f"{embeddings_url}/embeddings", json={"input": text, "model": "bge-small-en-v1.5"}
json={"input": text, "model": "bge-small-en-v1.5"}
) )
result = response.json() result = response.json()
return result["data"][0]["embedding"] return result["data"][0]["embedding"]
@dsl.component( @dsl.component(base_image="python:3.13-slim", packages_to_install=["pymilvus"])
base_image="python:3.13-slim",
packages_to_install=["pymilvus"]
)
def retrieve_context( def retrieve_context(
embedding: list, embedding: list,
milvus_host: str = "milvus.ai-ml.svc.cluster.local", milvus_host: str = "milvus.ai-ml.svc.cluster.local",
collection_name: str = "knowledge_base", collection_name: str = "knowledge_base",
top_k: int = 5 top_k: int = 5,
) -> list: ) -> list:
"""Retrieve relevant documents from Milvus vector database.""" """Retrieve relevant documents from Milvus vector database."""
from pymilvus import connections, Collection, utility from pymilvus import connections, Collection, utility
connections.connect(host=milvus_host, port=19530) connections.connect(host=milvus_host, port=19530)
if not utility.has_collection(collection_name): if not utility.has_collection(collection_name):
return [] return []
collection = Collection(collection_name) collection = Collection(collection_name)
collection.load() collection.load()
@@ -86,30 +74,29 @@ def retrieve_context(
anns_field="embedding", anns_field="embedding",
param={"metric_type": "COSINE", "params": {"nprobe": 10}}, param={"metric_type": "COSINE", "params": {"nprobe": 10}},
limit=top_k, limit=top_k,
output_fields=["text", "source"] output_fields=["text", "source"],
) )
documents = [] documents = []
for hits in results: for hits in results:
for hit in hits: for hit in hits:
documents.append({ documents.append(
"text": hit.entity.get("text"), {
"source": hit.entity.get("source"), "text": hit.entity.get("text"),
"score": hit.distance "source": hit.entity.get("source"),
}) "score": hit.distance,
}
)
return documents return documents
@dsl.component( @dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"])
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def rerank_documents( def rerank_documents(
query: str, query: str,
documents: list, documents: list,
reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local", reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local",
top_k: int = 3 top_k: int = 3,
) -> list: ) -> list:
"""Rerank documents using BGE reranker.""" """Rerank documents using BGE reranker."""
import httpx import httpx
@@ -123,30 +110,25 @@ def rerank_documents(
json={ json={
"query": query, "query": query,
"documents": [doc["text"] for doc in documents], "documents": [doc["text"] for doc in documents],
"model": "bge-reranker-v2-m3" "model": "bge-reranker-v2-m3",
} },
) )
result = response.json() result = response.json()
# Sort by rerank score # Sort by rerank score
reranked = sorted( reranked = sorted(
zip(documents, result.get("scores", [0] * len(documents))), zip(documents, result.get("scores", [0] * len(documents))), key=lambda x: x[1], reverse=True
key=lambda x: x[1],
reverse=True
)[:top_k] )[:top_k]
return [doc for doc, score in reranked] return [doc for doc, score in reranked]
@dsl.component( @dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"])
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def generate_response( def generate_response(
query: str, query: str,
context: list, context: list,
vllm_url: str = "http://llm-draft.ai-ml.svc.cluster.local:8000", vllm_url: str = "http://llm-draft.ai-ml.svc.cluster.local:8000",
model: str = "mistralai/Mistral-7B-Instruct-v0.3" model: str = "mistralai/Mistral-7B-Instruct-v0.3",
) -> str: ) -> str:
"""Generate response using vLLM.""" """Generate response using vLLM."""
import httpx import httpx
@@ -164,31 +146,22 @@ Keep responses concise and natural for speech synthesis."""
messages = [ messages = [
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": user_content} {"role": "user", "content": user_content},
] ]
with httpx.Client(timeout=180.0) as client: with httpx.Client(timeout=180.0) as client:
response = client.post( response = client.post(
f"{vllm_url}/v1/chat/completions", f"{vllm_url}/v1/chat/completions",
json={ json={"model": model, "messages": messages, "max_tokens": 512, "temperature": 0.7},
"model": model,
"messages": messages,
"max_tokens": 512,
"temperature": 0.7
}
) )
result = response.json() result = response.json()
return result["choices"][0]["message"]["content"] return result["choices"][0]["message"]["content"]
@dsl.component( @dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"])
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def synthesize_speech( def synthesize_speech(
text: str, text: str, tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local"
tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local"
) -> str: ) -> str:
"""Convert text to speech using TTS service.""" """Convert text to speech using TTS service."""
import base64 import base64
@@ -197,11 +170,7 @@ def synthesize_speech(
with httpx.Client(timeout=120.0) as client: with httpx.Client(timeout=120.0) as client:
response = client.post( response = client.post(
f"{tts_url}/v1/audio/speech", f"{tts_url}/v1/audio/speech",
json={ json={"input": text, "voice": "en_US-lessac-high", "response_format": "wav"},
"input": text,
"voice": "en_US-lessac-high",
"response_format": "wav"
}
) )
audio_b64 = base64.b64encode(response.content).decode("utf-8") audio_b64 = base64.b64encode(response.content).decode("utf-8")
@@ -210,20 +179,17 @@ def synthesize_speech(
@dsl.pipeline( @dsl.pipeline(
name="voice-assistant-rag-pipeline", name="voice-assistant-rag-pipeline",
description="End-to-end voice assistant with RAG: STT -> Embeddings -> Milvus -> Rerank -> LLM -> TTS" description="End-to-end voice assistant with RAG: STT -> Embeddings -> Milvus -> Rerank -> LLM -> TTS",
) )
def voice_assistant_pipeline( def voice_assistant_pipeline(audio_b64: str, collection_name: str = "knowledge_base"):
audio_b64: str,
collection_name: str = "knowledge_base"
):
""" """
Voice Assistant Pipeline with RAG Voice Assistant Pipeline with RAG
Args: Args:
audio_b64: Base64-encoded audio file audio_b64: Base64-encoded audio file
collection_name: Milvus collection for RAG collection_name: Milvus collection for RAG
""" """
# Step 1: Transcribe audio with Whisper # Step 1: Transcribe audio with Whisper
transcribe_task = transcribe_audio(audio_b64=audio_b64) transcribe_task = transcribe_audio(audio_b64=audio_b64)
transcribe_task.set_caching_options(enable_caching=False) transcribe_task.set_caching_options(enable_caching=False)
@@ -233,70 +199,47 @@ def voice_assistant_pipeline(
embed_task.set_caching_options(enable_caching=True) embed_task.set_caching_options(enable_caching=True)
# Step 3: Retrieve context from Milvus # Step 3: Retrieve context from Milvus
retrieve_task = retrieve_context( retrieve_task = retrieve_context(embedding=embed_task.output, collection_name=collection_name)
embedding=embed_task.output,
collection_name=collection_name
)
# Step 4: Rerank documents # Step 4: Rerank documents
rerank_task = rerank_documents( rerank_task = rerank_documents(query=transcribe_task.output, documents=retrieve_task.output)
query=transcribe_task.output,
documents=retrieve_task.output
)
# Step 5: Generate response with context # Step 5: Generate response with context
llm_task = generate_response( llm_task = generate_response(query=transcribe_task.output, context=rerank_task.output)
query=transcribe_task.output,
context=rerank_task.output
)
# Step 6: Synthesize speech # Step 6: Synthesize speech
tts_task = synthesize_speech(text=llm_task.output) tts_task = synthesize_speech(text=llm_task.output) # noqa: F841
@dsl.pipeline( @dsl.pipeline(name="text-to-speech-pipeline", description="Simple text to speech pipeline")
name="text-to-speech-pipeline",
description="Simple text to speech pipeline"
)
def text_to_speech_pipeline(text: str): def text_to_speech_pipeline(text: str):
"""Simple TTS pipeline for testing.""" """Simple TTS pipeline for testing."""
tts_task = synthesize_speech(text=text) tts_task = synthesize_speech(text=text) # noqa: F841
@dsl.pipeline( @dsl.pipeline(
name="rag-query-pipeline", name="rag-query-pipeline", description="RAG query pipeline: Embed -> Retrieve -> Rerank -> LLM"
description="RAG query pipeline: Embed -> Retrieve -> Rerank -> LLM"
) )
def rag_query_pipeline( def rag_query_pipeline(query: str, collection_name: str = "knowledge_base"):
query: str,
collection_name: str = "knowledge_base"
):
""" """
RAG Query Pipeline (text input, no voice) RAG Query Pipeline (text input, no voice)
Args: Args:
query: Text query query: Text query
collection_name: Milvus collection name collection_name: Milvus collection name
""" """
# Embed the query # Embed the query
embed_task = generate_embeddings(text=query) embed_task = generate_embeddings(text=query)
# Retrieve from Milvus # Retrieve from Milvus
retrieve_task = retrieve_context( retrieve_task = retrieve_context(embedding=embed_task.output, collection_name=collection_name)
embedding=embed_task.output,
collection_name=collection_name
)
# Rerank # Rerank
rerank_task = rerank_documents( rerank_task = rerank_documents(query=query, documents=retrieve_task.output)
query=query,
documents=retrieve_task.output
)
# Generate response # Generate response
llm_task = generate_response( llm_task = generate_response( # noqa: F841
query=query, query=query, context=rerank_task.output
context=rerank_task.output
) )
@@ -307,10 +250,10 @@ if __name__ == "__main__":
("tts_pipeline.yaml", text_to_speech_pipeline), ("tts_pipeline.yaml", text_to_speech_pipeline),
("rag_pipeline.yaml", rag_query_pipeline), ("rag_pipeline.yaml", rag_query_pipeline),
] ]
for filename, pipeline_func in pipelines: for filename, pipeline_func in pipelines:
compiler.Compiler().compile(pipeline_func, filename) compiler.Compiler().compile(pipeline_func, filename)
print(f"Compiled: {filename}") print(f"Compiled: {filename}")
print("\nUpload these YAML files to Kubeflow Pipelines UI at:") print("\nUpload these YAML files to Kubeflow Pipelines UI at:")
print(" http://kubeflow.example.com/pipelines") print(" http://kubeflow.example.com/pipelines")

View File

@@ -22,6 +22,9 @@ dev = [
requires = ["hatchling"] requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel] [tool.hatch.build.targets.wheel]
packages = ["."] packages = ["."]
only-include = ["voice_assistant.py"] only-include = ["voice_assistant.py"]

View File

@@ -1,11 +1,11 @@
""" """
Pytest configuration and fixtures for voice-assistant tests. Pytest configuration and fixtures for voice-assistant tests.
""" """
import asyncio import asyncio
import base64 import base64
import os import os
from typing import AsyncGenerator from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@@ -29,21 +29,54 @@ def event_loop():
def sample_audio_b64(): def sample_audio_b64():
"""Sample base64 encoded audio for testing.""" """Sample base64 encoded audio for testing."""
# 16-bit PCM silence (44 bytes header + 1000 samples) # 16-bit PCM silence (44 bytes header + 1000 samples)
wav_header = bytes([ wav_header = bytes(
0x52, 0x49, 0x46, 0x46, # "RIFF" [
0x24, 0x08, 0x00, 0x00, # File size 0x52,
0x57, 0x41, 0x56, 0x45, # "WAVE" 0x49,
0x66, 0x6D, 0x74, 0x20, # "fmt " 0x46,
0x10, 0x00, 0x00, 0x00, # Chunk size 0x46, # "RIFF"
0x01, 0x00, # PCM format 0x24,
0x01, 0x00, # Mono 0x08,
0x80, 0x3E, 0x00, 0x00, # Sample rate (16000) 0x00,
0x00, 0x7D, 0x00, 0x00, # Byte rate 0x00, # File size
0x02, 0x00, # Block align 0x57,
0x10, 0x00, # Bits per sample 0x41,
0x64, 0x61, 0x74, 0x61, # "data" 0x56,
0x00, 0x08, 0x00, 0x00, # Data size 0x45, # "WAVE"
]) 0x66,
0x6D,
0x74,
0x20, # "fmt "
0x10,
0x00,
0x00,
0x00, # Chunk size
0x01,
0x00, # PCM format
0x01,
0x00, # Mono
0x80,
0x3E,
0x00,
0x00, # Sample rate (16000)
0x00,
0x7D,
0x00,
0x00, # Byte rate
0x02,
0x00, # Block align
0x10,
0x00, # Bits per sample
0x64,
0x61,
0x74,
0x61, # "data"
0x00,
0x08,
0x00,
0x00, # Data size
]
)
silence = bytes([0x00] * 2048) silence = bytes([0x00] * 2048)
return base64.b64encode(wav_header + silence).decode() return base64.b64encode(wav_header + silence).decode()
@@ -96,13 +129,14 @@ def mock_voice_request(sample_audio_b64):
@pytest.fixture @pytest.fixture
def mock_clients(): def mock_clients():
"""Mock all service clients.""" """Mock all service clients."""
with patch("voice_assistant.STTClient") as stt, \ with (
patch("voice_assistant.EmbeddingsClient") as embeddings, \ patch("voice_assistant.STTClient") as stt,
patch("voice_assistant.RerankerClient") as reranker, \ patch("voice_assistant.EmbeddingsClient") as embeddings,
patch("voice_assistant.LLMClient") as llm, \ patch("voice_assistant.RerankerClient") as reranker,
patch("voice_assistant.TTSClient") as tts, \ patch("voice_assistant.LLMClient") as llm,
patch("voice_assistant.MilvusClient") as milvus: patch("voice_assistant.TTSClient") as tts,
patch("voice_assistant.MilvusClient") as milvus,
):
yield { yield {
"stt": stt, "stt": stt,
"embeddings": embeddings, "embeddings": embeddings,

View File

@@ -1,9 +1,9 @@
""" """
Unit tests for VoiceAssistant handler. Unit tests for VoiceAssistant handler.
""" """
import base64
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, patch
# Import after environment is set up in conftest # Import after environment is set up in conftest
from voice_assistant import VoiceAssistant, VoiceSettings from voice_assistant import VoiceAssistant, VoiceSettings
@@ -11,11 +11,11 @@ from voice_assistant import VoiceAssistant, VoiceSettings
class TestVoiceSettings: class TestVoiceSettings:
"""Tests for VoiceSettings configuration.""" """Tests for VoiceSettings configuration."""
def test_default_settings(self): def test_default_settings(self):
"""Test default settings values.""" """Test default settings values."""
settings = VoiceSettings() settings = VoiceSettings()
assert settings.service_name == "voice-assistant" assert settings.service_name == "voice-assistant"
assert settings.rag_top_k == 10 assert settings.rag_top_k == 10
assert settings.rag_rerank_top_k == 5 assert settings.rag_rerank_top_k == 5
@@ -24,37 +24,35 @@ class TestVoiceSettings:
assert settings.tts_language == "en" assert settings.tts_language == "en"
assert settings.include_transcription is True assert settings.include_transcription is True
assert settings.include_sources is False assert settings.include_sources is False
def test_custom_settings(self, monkeypatch): def test_custom_settings(self, monkeypatch):
"""Test settings from environment.""" """Test settings from environment."""
monkeypatch.setenv("RAG_TOP_K", "20") monkeypatch.setenv("RAG_TOP_K", "20")
monkeypatch.setenv("RAG_COLLECTION", "custom_collection") monkeypatch.setenv("RAG_COLLECTION", "custom_collection")
# Note: Would need to re-instantiate settings to pick up env vars # Note: Would need to re-instantiate settings to pick up env vars
settings = VoiceSettings( settings = VoiceSettings(rag_top_k=20, rag_collection="custom_collection")
rag_top_k=20,
rag_collection="custom_collection"
)
assert settings.rag_top_k == 20 assert settings.rag_top_k == 20
assert settings.rag_collection == "custom_collection" assert settings.rag_collection == "custom_collection"
class TestVoiceAssistant: class TestVoiceAssistant:
"""Tests for VoiceAssistant handler.""" """Tests for VoiceAssistant handler."""
@pytest.fixture @pytest.fixture
def handler(self): def handler(self):
"""Create handler with mocked clients.""" """Create handler with mocked clients."""
with patch("voice_assistant.STTClient"), \ with (
patch("voice_assistant.EmbeddingsClient"), \ patch("voice_assistant.STTClient"),
patch("voice_assistant.RerankerClient"), \ patch("voice_assistant.EmbeddingsClient"),
patch("voice_assistant.LLMClient"), \ patch("voice_assistant.RerankerClient"),
patch("voice_assistant.TTSClient"), \ patch("voice_assistant.LLMClient"),
patch("voice_assistant.MilvusClient"): patch("voice_assistant.TTSClient"),
patch("voice_assistant.MilvusClient"),
):
handler = VoiceAssistant() handler = VoiceAssistant()
# Setup mock clients # Setup mock clients
handler.stt = AsyncMock() handler.stt = AsyncMock()
handler.embeddings = AsyncMock() handler.embeddings = AsyncMock()
@@ -63,15 +61,15 @@ class TestVoiceAssistant:
handler.tts = AsyncMock() handler.tts = AsyncMock()
handler.milvus = AsyncMock() handler.milvus = AsyncMock()
handler.nats = AsyncMock() handler.nats = AsyncMock()
yield handler yield handler
def test_init(self, handler): def test_init(self, handler):
"""Test handler initialization.""" """Test handler initialization."""
assert handler.subject == "voice.request" assert handler.subject == "voice.request"
assert handler.queue_group == "voice-assistants" assert handler.queue_group == "voice-assistants"
assert handler.voice_settings.service_name == "voice-assistant" assert handler.voice_settings.service_name == "voice-assistant"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_message_success( async def test_handle_message_success(
self, self,
@@ -90,16 +88,16 @@ class TestVoiceAssistant:
handler.reranker.rerank.return_value = sample_reranked handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "Machine learning is a type of AI." handler.llm.generate.return_value = "Machine learning is a type of AI."
handler.tts.synthesize.return_value = b"audio_bytes" handler.tts.synthesize.return_value = b"audio_bytes"
# Execute # Execute
result = await handler.handle_message(mock_nats_message, mock_voice_request) result = await handler.handle_message(mock_nats_message, mock_voice_request)
# Verify # Verify
assert result["request_id"] == "test-request-123" assert result["request_id"] == "test-request-123"
assert result["response"] == "Machine learning is a type of AI." assert result["response"] == "Machine learning is a type of AI."
assert "audio" in result assert "audio" in result
assert result["transcription"] == "What is machine learning?" assert result["transcription"] == "What is machine learning?"
# Verify pipeline was called # Verify pipeline was called
handler.stt.transcribe.assert_called_once() handler.stt.transcribe.assert_called_once()
handler.embeddings.embed_single.assert_called_once() handler.embeddings.embed_single.assert_called_once()
@@ -107,7 +105,7 @@ class TestVoiceAssistant:
handler.reranker.rerank.assert_called_once() handler.reranker.rerank.assert_called_once()
handler.llm.generate.assert_called_once() handler.llm.generate.assert_called_once()
handler.tts.synthesize.assert_called_once() handler.tts.synthesize.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_message_empty_transcription( async def test_handle_message_empty_transcription(
self, self,
@@ -117,15 +115,15 @@ class TestVoiceAssistant:
): ):
"""Test handling when transcription is empty.""" """Test handling when transcription is empty."""
handler.stt.transcribe.return_value = {"text": ""} handler.stt.transcribe.return_value = {"text": ""}
result = await handler.handle_message(mock_nats_message, mock_voice_request) result = await handler.handle_message(mock_nats_message, mock_voice_request)
assert "error" in result assert "error" in result
assert result["error"] == "Could not transcribe audio" assert result["error"] == "Could not transcribe audio"
# Verify pipeline stopped after transcription # Verify pipeline stopped after transcription
handler.embeddings.embed_single.assert_not_called() handler.embeddings.embed_single.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_message_with_sources( async def test_handle_message_with_sources(
self, self,
@@ -138,7 +136,7 @@ class TestVoiceAssistant:
): ):
"""Test response includes sources when enabled.""" """Test response includes sources when enabled."""
handler.voice_settings.include_sources = True handler.voice_settings.include_sources = True
# Setup mocks # Setup mocks
handler.stt.transcribe.return_value = {"text": "Hello"} handler.stt.transcribe.return_value = {"text": "Hello"}
handler.embeddings.embed_single.return_value = sample_embedding handler.embeddings.embed_single.return_value = sample_embedding
@@ -146,51 +144,52 @@ class TestVoiceAssistant:
handler.reranker.rerank.return_value = sample_reranked handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "Hi there!" handler.llm.generate.return_value = "Hi there!"
handler.tts.synthesize.return_value = b"audio" handler.tts.synthesize.return_value = b"audio"
result = await handler.handle_message(mock_nats_message, mock_voice_request) result = await handler.handle_message(mock_nats_message, mock_voice_request)
assert "sources" in result assert "sources" in result
assert len(result["sources"]) <= 3 assert len(result["sources"]) <= 3
def test_build_context(self, handler): def test_build_context(self, handler):
"""Test context building from documents.""" """Test context building from documents."""
documents = [ documents = [
{"document": "First doc content"}, {"document": "First doc content"},
{"document": "Second doc content"}, {"document": "Second doc content"},
] ]
context = handler._build_context(documents) context = handler._build_context(documents)
assert "First doc content" in context assert "First doc content" in context
assert "Second doc content" in context assert "Second doc content" in context
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_initializes_clients(self): async def test_setup_initializes_clients(self):
"""Test that setup initializes all clients.""" """Test that setup initializes all clients."""
with patch("voice_assistant.STTClient") as stt_cls, \ with (
patch("voice_assistant.EmbeddingsClient") as emb_cls, \ patch("voice_assistant.STTClient") as stt_cls,
patch("voice_assistant.RerankerClient") as rer_cls, \ patch("voice_assistant.EmbeddingsClient") as emb_cls,
patch("voice_assistant.LLMClient") as llm_cls, \ patch("voice_assistant.RerankerClient") as rer_cls,
patch("voice_assistant.TTSClient") as tts_cls, \ patch("voice_assistant.LLMClient") as llm_cls,
patch("voice_assistant.MilvusClient") as mil_cls: patch("voice_assistant.TTSClient") as tts_cls,
patch("voice_assistant.MilvusClient") as mil_cls,
):
mil_cls.return_value.connect = AsyncMock() mil_cls.return_value.connect = AsyncMock()
handler = VoiceAssistant() handler = VoiceAssistant()
await handler.setup() await handler.setup()
stt_cls.assert_called_once() stt_cls.assert_called_once()
emb_cls.assert_called_once() emb_cls.assert_called_once()
rer_cls.assert_called_once() rer_cls.assert_called_once()
llm_cls.assert_called_once() llm_cls.assert_called_once()
tts_cls.assert_called_once() tts_cls.assert_called_once()
mil_cls.assert_called_once() mil_cls.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_teardown_closes_clients(self, handler): async def test_teardown_closes_clients(self, handler):
"""Test that teardown closes all clients.""" """Test that teardown closes all clients."""
await handler.teardown() await handler.teardown()
handler.stt.close.assert_called_once() handler.stt.close.assert_called_once()
handler.embeddings.close.assert_called_once() handler.embeddings.close.assert_called_once()
handler.reranker.close.assert_called_once() handler.reranker.close.assert_called_once()

View File

@@ -12,6 +12,7 @@ End-to-end voice assistant pipeline using handler-base:
7. Synthesize speech with XTTS 7. Synthesize speech with XTTS
8. Publish result to NATS "voice.response.{request_id}" 8. Publish result to NATS "voice.response.{request_id}"
""" """
import base64 import base64
import logging import logging
from typing import Any, Optional from typing import Any, Optional
@@ -34,18 +35,18 @@ logger = logging.getLogger("voice-assistant")
class VoiceSettings(Settings): class VoiceSettings(Settings):
"""Voice assistant specific settings.""" """Voice assistant specific settings."""
service_name: str = "voice-assistant" service_name: str = "voice-assistant"
# RAG settings # RAG settings
rag_top_k: int = 10 rag_top_k: int = 10
rag_rerank_top_k: int = 5 rag_rerank_top_k: int = 5
rag_collection: str = "documents" rag_collection: str = "documents"
# Audio settings # Audio settings
stt_language: Optional[str] = None # Auto-detect stt_language: Optional[str] = None # Auto-detect
tts_language: str = "en" tts_language: str = "en"
# Response settings # Response settings
include_transcription: bool = True include_transcription: bool = True
include_sources: bool = False include_sources: bool = False
@@ -54,7 +55,7 @@ class VoiceSettings(Settings):
class VoiceAssistant(Handler): class VoiceAssistant(Handler):
""" """
Voice request handler with full STT -> RAG -> LLM -> TTS pipeline. Voice request handler with full STT -> RAG -> LLM -> TTS pipeline.
Request format (msgpack): Request format (msgpack):
{ {
"request_id": "uuid", "request_id": "uuid",
@@ -62,7 +63,7 @@ class VoiceAssistant(Handler):
"language": "optional language code", "language": "optional language code",
"collection": "optional collection name" "collection": "optional collection name"
} }
Response format: Response format:
{ {
"request_id": "uuid", "request_id": "uuid",
@@ -71,7 +72,7 @@ class VoiceAssistant(Handler):
"audio": "base64 encoded response audio" "audio": "base64 encoded response audio"
} }
""" """
def __init__(self): def __init__(self):
self.voice_settings = VoiceSettings() self.voice_settings = VoiceSettings()
super().__init__( super().__init__(
@@ -79,121 +80,116 @@ class VoiceAssistant(Handler):
settings=self.voice_settings, settings=self.voice_settings,
queue_group="voice-assistants", queue_group="voice-assistants",
) )
async def setup(self) -> None: async def setup(self) -> None:
"""Initialize service clients.""" """Initialize service clients."""
logger.info("Initializing voice assistant clients...") logger.info("Initializing voice assistant clients...")
self.stt = STTClient(self.voice_settings) self.stt = STTClient(self.voice_settings)
self.embeddings = EmbeddingsClient(self.voice_settings) self.embeddings = EmbeddingsClient(self.voice_settings)
self.reranker = RerankerClient(self.voice_settings) self.reranker = RerankerClient(self.voice_settings)
self.llm = LLMClient(self.voice_settings) self.llm = LLMClient(self.voice_settings)
self.tts = TTSClient(self.voice_settings) self.tts = TTSClient(self.voice_settings)
self.milvus = MilvusClient(self.voice_settings) self.milvus = MilvusClient(self.voice_settings)
await self.milvus.connect(self.voice_settings.rag_collection) await self.milvus.connect(self.voice_settings.rag_collection)
logger.info("Voice assistant clients initialized") logger.info("Voice assistant clients initialized")
async def teardown(self) -> None: async def teardown(self) -> None:
"""Clean up service clients.""" """Clean up service clients."""
logger.info("Closing voice assistant clients...") logger.info("Closing voice assistant clients...")
await self.stt.close() await self.stt.close()
await self.embeddings.close() await self.embeddings.close()
await self.reranker.close() await self.reranker.close()
await self.llm.close() await self.llm.close()
await self.tts.close() await self.tts.close()
await self.milvus.close() await self.milvus.close()
logger.info("Voice assistant clients closed") logger.info("Voice assistant clients closed")
async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]: async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
"""Handle incoming voice request.""" """Handle incoming voice request."""
request_id = data.get("request_id", "unknown") request_id = data.get("request_id", "unknown")
audio_b64 = data.get("audio", "") audio_b64 = data.get("audio", "")
language = data.get("language", self.voice_settings.stt_language) language = data.get("language", self.voice_settings.stt_language)
collection = data.get("collection", self.voice_settings.rag_collection) collection = data.get("collection", self.voice_settings.rag_collection)
logger.info(f"Processing voice request {request_id}") logger.info(f"Processing voice request {request_id}")
with create_span("voice.process") as span: with create_span("voice.process") as span:
if span: if span:
span.set_attribute("request.id", request_id) span.set_attribute("request.id", request_id)
# 1. Decode audio # 1. Decode audio
audio_bytes = base64.b64decode(audio_b64) audio_bytes = base64.b64decode(audio_b64)
# 2. Transcribe audio to text # 2. Transcribe audio to text
transcription = await self._transcribe(audio_bytes, language) transcription = await self._transcribe(audio_bytes, language)
query = transcription.get("text", "") query = transcription.get("text", "")
if not query.strip(): if not query.strip():
logger.warning(f"Empty transcription for request {request_id}") logger.warning(f"Empty transcription for request {request_id}")
return { return {
"request_id": request_id, "request_id": request_id,
"error": "Could not transcribe audio", "error": "Could not transcribe audio",
} }
logger.info(f"Transcribed: {query[:50]}...") logger.info(f"Transcribed: {query[:50]}...")
# 3. Generate query embedding # 3. Generate query embedding
embedding = await self._get_embedding(query) embedding = await self._get_embedding(query)
# 4. Search Milvus for context # 4. Search Milvus for context
documents = await self._search_context(embedding, collection) documents = await self._search_context(embedding, collection)
# 5. Rerank documents # 5. Rerank documents
reranked = await self._rerank_documents(query, documents) reranked = await self._rerank_documents(query, documents)
# 6. Build context # 6. Build context
context = self._build_context(reranked) context = self._build_context(reranked)
# 7. Generate LLM response # 7. Generate LLM response
response_text = await self._generate_response(query, context) response_text = await self._generate_response(query, context)
# 8. Synthesize speech # 8. Synthesize speech
response_audio = await self._synthesize_speech(response_text) response_audio = await self._synthesize_speech(response_text)
# Build response # Build response
result = { result = {
"request_id": request_id, "request_id": request_id,
"response": response_text, "response": response_text,
"audio": response_audio, "audio": response_audio,
} }
if self.voice_settings.include_transcription: if self.voice_settings.include_transcription:
result["transcription"] = query result["transcription"] = query
if self.voice_settings.include_sources: if self.voice_settings.include_sources:
result["sources"] = [ result["sources"] = [
{"text": d["document"][:200], "score": d["score"]} {"text": d["document"][:200], "score": d["score"]} for d in reranked[:3]
for d in reranked[:3]
] ]
logger.info(f"Completed voice request {request_id}") logger.info(f"Completed voice request {request_id}")
# Publish to response subject # Publish to response subject
response_subject = f"voice.response.{request_id}" response_subject = f"voice.response.{request_id}"
await self.nats.publish(response_subject, result) await self.nats.publish(response_subject, result)
return result return result
async def _transcribe( async def _transcribe(self, audio: bytes, language: Optional[str]) -> dict:
self, audio: bytes, language: Optional[str]
) -> dict:
"""Transcribe audio to text.""" """Transcribe audio to text."""
with create_span("voice.stt"): with create_span("voice.stt"):
return await self.stt.transcribe(audio, language=language) return await self.stt.transcribe(audio, language=language)
async def _get_embedding(self, text: str) -> list[float]: async def _get_embedding(self, text: str) -> list[float]:
"""Generate embedding for query text.""" """Generate embedding for query text."""
with create_span("voice.embedding"): with create_span("voice.embedding"):
return await self.embeddings.embed_single(text) return await self.embeddings.embed_single(text)
async def _search_context( async def _search_context(self, embedding: list[float], collection: str) -> list[dict]:
self, embedding: list[float], collection: str
) -> list[dict]:
"""Search Milvus for relevant documents.""" """Search Milvus for relevant documents."""
with create_span("voice.search"): with create_span("voice.search"):
return await self.milvus.search_with_texts( return await self.milvus.search_with_texts(
@@ -201,32 +197,28 @@ class VoiceAssistant(Handler):
limit=self.voice_settings.rag_top_k, limit=self.voice_settings.rag_top_k,
text_field="text", text_field="text",
) )
async def _rerank_documents( async def _rerank_documents(self, query: str, documents: list[dict]) -> list[dict]:
self, query: str, documents: list[dict]
) -> list[dict]:
"""Rerank documents by relevance.""" """Rerank documents by relevance."""
with create_span("voice.rerank"): with create_span("voice.rerank"):
texts = [d.get("text", "") for d in documents] texts = [d.get("text", "") for d in documents]
return await self.reranker.rerank( return await self.reranker.rerank(
query, texts, top_k=self.voice_settings.rag_rerank_top_k query, texts, top_k=self.voice_settings.rag_rerank_top_k
) )
def _build_context(self, documents: list[dict]) -> str: def _build_context(self, documents: list[dict]) -> str:
"""Build context string from ranked documents.""" """Build context string from ranked documents."""
return "\n\n".join(d.get("document", "") for d in documents) return "\n\n".join(d.get("document", "") for d in documents)
async def _generate_response(self, query: str, context: str) -> str: async def _generate_response(self, query: str, context: str) -> str:
"""Generate LLM response.""" """Generate LLM response."""
with create_span("voice.generate"): with create_span("voice.generate"):
return await self.llm.generate(query, context=context) return await self.llm.generate(query, context=context)
async def _synthesize_speech(self, text: str) -> str: async def _synthesize_speech(self, text: str) -> str:
"""Synthesize speech and return base64.""" """Synthesize speech and return base64."""
with create_span("voice.tts"): with create_span("voice.tts"):
audio_bytes = await self.tts.synthesize( audio_bytes = await self.tts.synthesize(text, language=self.voice_settings.tts_language)
text, language=self.voice_settings.tts_language
)
return base64.b64encode(audio_bytes).decode() return base64.b64encode(audio_bytes).decode()