From 04624123536a25266712f0b283d01751e619547d Mon Sep 17 00:00:00 2001 From: "Billy D." Date: Mon, 2 Feb 2026 08:44:14 -0500 Subject: [PATCH] fix: ruff formatting, allow-direct-references, and noqa for Kubeflow pipeline vars --- pipelines/voice_pipeline.py | 163 +++++++++++----------------------- pyproject.toml | 3 + tests/conftest.py | 82 ++++++++++++----- tests/test_voice_assistant.py | 97 ++++++++++---------- voice_assistant.py | 100 ++++++++++----------- 5 files changed, 208 insertions(+), 237 deletions(-) diff --git a/pipelines/voice_pipeline.py b/pipelines/voice_pipeline.py index 886f6c0..c3962d1 100644 --- a/pipelines/voice_pipeline.py +++ b/pipelines/voice_pipeline.py @@ -14,70 +14,58 @@ from kfp import dsl from kfp import compiler -@dsl.component( - base_image="python:3.13-slim", - packages_to_install=["httpx"] -) +@dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"]) def transcribe_audio( - audio_b64: str, - whisper_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local" + audio_b64: str, whisper_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local" ) -> str: """Transcribe audio using Whisper STT service.""" import base64 import httpx audio_bytes = base64.b64decode(audio_b64) - + with httpx.Client(timeout=120.0) as client: response = client.post( f"{whisper_url}/v1/audio/transcriptions", files={"file": ("audio.wav", audio_bytes, "audio/wav")}, - data={"model": "whisper-large-v3", "language": "en"} + data={"model": "whisper-large-v3", "language": "en"}, ) result = response.json() return result.get("text", "") -@dsl.component( - base_image="python:3.13-slim", - packages_to_install=["httpx"] -) +@dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"]) def generate_embeddings( - text: str, - embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local" + text: str, embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local" ) -> list: """Generate embeddings for RAG retrieval.""" import httpx with httpx.Client(timeout=60.0) as client: response = client.post( - f"{embeddings_url}/embeddings", - json={"input": text, "model": "bge-small-en-v1.5"} + f"{embeddings_url}/embeddings", json={"input": text, "model": "bge-small-en-v1.5"} ) result = response.json() return result["data"][0]["embedding"] -@dsl.component( - base_image="python:3.13-slim", - packages_to_install=["pymilvus"] -) +@dsl.component(base_image="python:3.13-slim", packages_to_install=["pymilvus"]) def retrieve_context( embedding: list, milvus_host: str = "milvus.ai-ml.svc.cluster.local", collection_name: str = "knowledge_base", - top_k: int = 5 + top_k: int = 5, ) -> list: """Retrieve relevant documents from Milvus vector database.""" from pymilvus import connections, Collection, utility connections.connect(host=milvus_host, port=19530) - + if not utility.has_collection(collection_name): return [] - + collection = Collection(collection_name) collection.load() @@ -86,30 +74,29 @@ def retrieve_context( anns_field="embedding", param={"metric_type": "COSINE", "params": {"nprobe": 10}}, limit=top_k, - output_fields=["text", "source"] + output_fields=["text", "source"], ) documents = [] for hits in results: for hit in hits: - documents.append({ - "text": hit.entity.get("text"), - "source": hit.entity.get("source"), - "score": hit.distance - }) + documents.append( + { + "text": hit.entity.get("text"), + "source": hit.entity.get("source"), + "score": hit.distance, + } + ) return documents -@dsl.component( - base_image="python:3.13-slim", - packages_to_install=["httpx"] -) +@dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"]) def rerank_documents( query: str, documents: list, reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local", - top_k: int = 3 + top_k: int = 3, ) -> list: """Rerank documents using BGE reranker.""" import httpx @@ -123,30 +110,25 @@ def rerank_documents( json={ "query": query, "documents": [doc["text"] for doc in documents], - "model": "bge-reranker-v2-m3" - } + "model": "bge-reranker-v2-m3", + }, ) result = response.json() # Sort by rerank score reranked = sorted( - zip(documents, result.get("scores", [0] * len(documents))), - key=lambda x: x[1], - reverse=True + zip(documents, result.get("scores", [0] * len(documents))), key=lambda x: x[1], reverse=True )[:top_k] return [doc for doc, score in reranked] -@dsl.component( - base_image="python:3.13-slim", - packages_to_install=["httpx"] -) +@dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"]) def generate_response( query: str, context: list, vllm_url: str = "http://llm-draft.ai-ml.svc.cluster.local:8000", - model: str = "mistralai/Mistral-7B-Instruct-v0.3" + model: str = "mistralai/Mistral-7B-Instruct-v0.3", ) -> str: """Generate response using vLLM.""" import httpx @@ -164,31 +146,22 @@ Keep responses concise and natural for speech synthesis.""" messages = [ {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_content} + {"role": "user", "content": user_content}, ] with httpx.Client(timeout=180.0) as client: response = client.post( f"{vllm_url}/v1/chat/completions", - json={ - "model": model, - "messages": messages, - "max_tokens": 512, - "temperature": 0.7 - } + json={"model": model, "messages": messages, "max_tokens": 512, "temperature": 0.7}, ) result = response.json() return result["choices"][0]["message"]["content"] -@dsl.component( - base_image="python:3.13-slim", - packages_to_install=["httpx"] -) +@dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"]) def synthesize_speech( - text: str, - tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local" + text: str, tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local" ) -> str: """Convert text to speech using TTS service.""" import base64 @@ -197,11 +170,7 @@ def synthesize_speech( with httpx.Client(timeout=120.0) as client: response = client.post( f"{tts_url}/v1/audio/speech", - json={ - "input": text, - "voice": "en_US-lessac-high", - "response_format": "wav" - } + json={"input": text, "voice": "en_US-lessac-high", "response_format": "wav"}, ) audio_b64 = base64.b64encode(response.content).decode("utf-8") @@ -210,20 +179,17 @@ def synthesize_speech( @dsl.pipeline( name="voice-assistant-rag-pipeline", - description="End-to-end voice assistant with RAG: STT -> Embeddings -> Milvus -> Rerank -> LLM -> TTS" + description="End-to-end voice assistant with RAG: STT -> Embeddings -> Milvus -> Rerank -> LLM -> TTS", ) -def voice_assistant_pipeline( - audio_b64: str, - collection_name: str = "knowledge_base" -): +def voice_assistant_pipeline(audio_b64: str, collection_name: str = "knowledge_base"): """ Voice Assistant Pipeline with RAG - + Args: audio_b64: Base64-encoded audio file collection_name: Milvus collection for RAG """ - + # Step 1: Transcribe audio with Whisper transcribe_task = transcribe_audio(audio_b64=audio_b64) transcribe_task.set_caching_options(enable_caching=False) @@ -233,70 +199,47 @@ def voice_assistant_pipeline( embed_task.set_caching_options(enable_caching=True) # Step 3: Retrieve context from Milvus - retrieve_task = retrieve_context( - embedding=embed_task.output, - collection_name=collection_name - ) + retrieve_task = retrieve_context(embedding=embed_task.output, collection_name=collection_name) # Step 4: Rerank documents - rerank_task = rerank_documents( - query=transcribe_task.output, - documents=retrieve_task.output - ) + rerank_task = rerank_documents(query=transcribe_task.output, documents=retrieve_task.output) # Step 5: Generate response with context - llm_task = generate_response( - query=transcribe_task.output, - context=rerank_task.output - ) + llm_task = generate_response(query=transcribe_task.output, context=rerank_task.output) # Step 6: Synthesize speech - tts_task = synthesize_speech(text=llm_task.output) + tts_task = synthesize_speech(text=llm_task.output) # noqa: F841 -@dsl.pipeline( - name="text-to-speech-pipeline", - description="Simple text to speech pipeline" -) +@dsl.pipeline(name="text-to-speech-pipeline", description="Simple text to speech pipeline") def text_to_speech_pipeline(text: str): """Simple TTS pipeline for testing.""" - tts_task = synthesize_speech(text=text) + tts_task = synthesize_speech(text=text) # noqa: F841 @dsl.pipeline( - name="rag-query-pipeline", - description="RAG query pipeline: Embed -> Retrieve -> Rerank -> LLM" + name="rag-query-pipeline", description="RAG query pipeline: Embed -> Retrieve -> Rerank -> LLM" ) -def rag_query_pipeline( - query: str, - collection_name: str = "knowledge_base" -): +def rag_query_pipeline(query: str, collection_name: str = "knowledge_base"): """ RAG Query Pipeline (text input, no voice) - + Args: query: Text query collection_name: Milvus collection name """ # Embed the query embed_task = generate_embeddings(text=query) - + # Retrieve from Milvus - retrieve_task = retrieve_context( - embedding=embed_task.output, - collection_name=collection_name - ) - + retrieve_task = retrieve_context(embedding=embed_task.output, collection_name=collection_name) + # Rerank - rerank_task = rerank_documents( - query=query, - documents=retrieve_task.output - ) - + rerank_task = rerank_documents(query=query, documents=retrieve_task.output) + # Generate response - llm_task = generate_response( - query=query, - context=rerank_task.output + llm_task = generate_response( # noqa: F841 + query=query, context=rerank_task.output ) @@ -307,10 +250,10 @@ if __name__ == "__main__": ("tts_pipeline.yaml", text_to_speech_pipeline), ("rag_pipeline.yaml", rag_query_pipeline), ] - + for filename, pipeline_func in pipelines: compiler.Compiler().compile(pipeline_func, filename) print(f"Compiled: {filename}") - + print("\nUpload these YAML files to Kubeflow Pipelines UI at:") print(" http://kubeflow.example.com/pipelines") diff --git a/pyproject.toml b/pyproject.toml index 8b22e63..78f070f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,9 @@ dev = [ requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.metadata] +allow-direct-references = true + [tool.hatch.build.targets.wheel] packages = ["."] only-include = ["voice_assistant.py"] diff --git a/tests/conftest.py b/tests/conftest.py index 7a2939b..2176d48 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,11 @@ """ Pytest configuration and fixtures for voice-assistant tests. """ + import asyncio import base64 import os -from typing import AsyncGenerator -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest @@ -29,21 +29,54 @@ def event_loop(): def sample_audio_b64(): """Sample base64 encoded audio for testing.""" # 16-bit PCM silence (44 bytes header + 1000 samples) - wav_header = bytes([ - 0x52, 0x49, 0x46, 0x46, # "RIFF" - 0x24, 0x08, 0x00, 0x00, # File size - 0x57, 0x41, 0x56, 0x45, # "WAVE" - 0x66, 0x6D, 0x74, 0x20, # "fmt " - 0x10, 0x00, 0x00, 0x00, # Chunk size - 0x01, 0x00, # PCM format - 0x01, 0x00, # Mono - 0x80, 0x3E, 0x00, 0x00, # Sample rate (16000) - 0x00, 0x7D, 0x00, 0x00, # Byte rate - 0x02, 0x00, # Block align - 0x10, 0x00, # Bits per sample - 0x64, 0x61, 0x74, 0x61, # "data" - 0x00, 0x08, 0x00, 0x00, # Data size - ]) + wav_header = bytes( + [ + 0x52, + 0x49, + 0x46, + 0x46, # "RIFF" + 0x24, + 0x08, + 0x00, + 0x00, # File size + 0x57, + 0x41, + 0x56, + 0x45, # "WAVE" + 0x66, + 0x6D, + 0x74, + 0x20, # "fmt " + 0x10, + 0x00, + 0x00, + 0x00, # Chunk size + 0x01, + 0x00, # PCM format + 0x01, + 0x00, # Mono + 0x80, + 0x3E, + 0x00, + 0x00, # Sample rate (16000) + 0x00, + 0x7D, + 0x00, + 0x00, # Byte rate + 0x02, + 0x00, # Block align + 0x10, + 0x00, # Bits per sample + 0x64, + 0x61, + 0x74, + 0x61, # "data" + 0x00, + 0x08, + 0x00, + 0x00, # Data size + ] + ) silence = bytes([0x00] * 2048) return base64.b64encode(wav_header + silence).decode() @@ -96,13 +129,14 @@ def mock_voice_request(sample_audio_b64): @pytest.fixture def mock_clients(): """Mock all service clients.""" - with patch("voice_assistant.STTClient") as stt, \ - patch("voice_assistant.EmbeddingsClient") as embeddings, \ - patch("voice_assistant.RerankerClient") as reranker, \ - patch("voice_assistant.LLMClient") as llm, \ - patch("voice_assistant.TTSClient") as tts, \ - patch("voice_assistant.MilvusClient") as milvus: - + with ( + patch("voice_assistant.STTClient") as stt, + patch("voice_assistant.EmbeddingsClient") as embeddings, + patch("voice_assistant.RerankerClient") as reranker, + patch("voice_assistant.LLMClient") as llm, + patch("voice_assistant.TTSClient") as tts, + patch("voice_assistant.MilvusClient") as milvus, + ): yield { "stt": stt, "embeddings": embeddings, diff --git a/tests/test_voice_assistant.py b/tests/test_voice_assistant.py index 69de754..151cc39 100644 --- a/tests/test_voice_assistant.py +++ b/tests/test_voice_assistant.py @@ -1,9 +1,9 @@ """ Unit tests for VoiceAssistant handler. """ -import base64 + import pytest -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch # Import after environment is set up in conftest from voice_assistant import VoiceAssistant, VoiceSettings @@ -11,11 +11,11 @@ from voice_assistant import VoiceAssistant, VoiceSettings class TestVoiceSettings: """Tests for VoiceSettings configuration.""" - + def test_default_settings(self): """Test default settings values.""" settings = VoiceSettings() - + assert settings.service_name == "voice-assistant" assert settings.rag_top_k == 10 assert settings.rag_rerank_top_k == 5 @@ -24,37 +24,35 @@ class TestVoiceSettings: assert settings.tts_language == "en" assert settings.include_transcription is True assert settings.include_sources is False - + def test_custom_settings(self, monkeypatch): """Test settings from environment.""" monkeypatch.setenv("RAG_TOP_K", "20") monkeypatch.setenv("RAG_COLLECTION", "custom_collection") - + # Note: Would need to re-instantiate settings to pick up env vars - settings = VoiceSettings( - rag_top_k=20, - rag_collection="custom_collection" - ) - + settings = VoiceSettings(rag_top_k=20, rag_collection="custom_collection") + assert settings.rag_top_k == 20 assert settings.rag_collection == "custom_collection" class TestVoiceAssistant: """Tests for VoiceAssistant handler.""" - + @pytest.fixture def handler(self): """Create handler with mocked clients.""" - with patch("voice_assistant.STTClient"), \ - patch("voice_assistant.EmbeddingsClient"), \ - patch("voice_assistant.RerankerClient"), \ - patch("voice_assistant.LLMClient"), \ - patch("voice_assistant.TTSClient"), \ - patch("voice_assistant.MilvusClient"): - + with ( + patch("voice_assistant.STTClient"), + patch("voice_assistant.EmbeddingsClient"), + patch("voice_assistant.RerankerClient"), + patch("voice_assistant.LLMClient"), + patch("voice_assistant.TTSClient"), + patch("voice_assistant.MilvusClient"), + ): handler = VoiceAssistant() - + # Setup mock clients handler.stt = AsyncMock() handler.embeddings = AsyncMock() @@ -63,15 +61,15 @@ class TestVoiceAssistant: handler.tts = AsyncMock() handler.milvus = AsyncMock() handler.nats = AsyncMock() - + yield handler - + def test_init(self, handler): """Test handler initialization.""" assert handler.subject == "voice.request" assert handler.queue_group == "voice-assistants" assert handler.voice_settings.service_name == "voice-assistant" - + @pytest.mark.asyncio async def test_handle_message_success( self, @@ -90,16 +88,16 @@ class TestVoiceAssistant: handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "Machine learning is a type of AI." handler.tts.synthesize.return_value = b"audio_bytes" - + # Execute result = await handler.handle_message(mock_nats_message, mock_voice_request) - + # Verify assert result["request_id"] == "test-request-123" assert result["response"] == "Machine learning is a type of AI." assert "audio" in result assert result["transcription"] == "What is machine learning?" - + # Verify pipeline was called handler.stt.transcribe.assert_called_once() handler.embeddings.embed_single.assert_called_once() @@ -107,7 +105,7 @@ class TestVoiceAssistant: handler.reranker.rerank.assert_called_once() handler.llm.generate.assert_called_once() handler.tts.synthesize.assert_called_once() - + @pytest.mark.asyncio async def test_handle_message_empty_transcription( self, @@ -117,15 +115,15 @@ class TestVoiceAssistant: ): """Test handling when transcription is empty.""" handler.stt.transcribe.return_value = {"text": ""} - + result = await handler.handle_message(mock_nats_message, mock_voice_request) - + assert "error" in result assert result["error"] == "Could not transcribe audio" - + # Verify pipeline stopped after transcription handler.embeddings.embed_single.assert_not_called() - + @pytest.mark.asyncio async def test_handle_message_with_sources( self, @@ -138,7 +136,7 @@ class TestVoiceAssistant: ): """Test response includes sources when enabled.""" handler.voice_settings.include_sources = True - + # Setup mocks handler.stt.transcribe.return_value = {"text": "Hello"} handler.embeddings.embed_single.return_value = sample_embedding @@ -146,51 +144,52 @@ class TestVoiceAssistant: handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "Hi there!" handler.tts.synthesize.return_value = b"audio" - + result = await handler.handle_message(mock_nats_message, mock_voice_request) - + assert "sources" in result assert len(result["sources"]) <= 3 - + def test_build_context(self, handler): """Test context building from documents.""" documents = [ {"document": "First doc content"}, {"document": "Second doc content"}, ] - + context = handler._build_context(documents) - + assert "First doc content" in context assert "Second doc content" in context - + @pytest.mark.asyncio async def test_setup_initializes_clients(self): """Test that setup initializes all clients.""" - with patch("voice_assistant.STTClient") as stt_cls, \ - patch("voice_assistant.EmbeddingsClient") as emb_cls, \ - patch("voice_assistant.RerankerClient") as rer_cls, \ - patch("voice_assistant.LLMClient") as llm_cls, \ - patch("voice_assistant.TTSClient") as tts_cls, \ - patch("voice_assistant.MilvusClient") as mil_cls: - + with ( + patch("voice_assistant.STTClient") as stt_cls, + patch("voice_assistant.EmbeddingsClient") as emb_cls, + patch("voice_assistant.RerankerClient") as rer_cls, + patch("voice_assistant.LLMClient") as llm_cls, + patch("voice_assistant.TTSClient") as tts_cls, + patch("voice_assistant.MilvusClient") as mil_cls, + ): mil_cls.return_value.connect = AsyncMock() - + handler = VoiceAssistant() await handler.setup() - + stt_cls.assert_called_once() emb_cls.assert_called_once() rer_cls.assert_called_once() llm_cls.assert_called_once() tts_cls.assert_called_once() mil_cls.assert_called_once() - + @pytest.mark.asyncio async def test_teardown_closes_clients(self, handler): """Test that teardown closes all clients.""" await handler.teardown() - + handler.stt.close.assert_called_once() handler.embeddings.close.assert_called_once() handler.reranker.close.assert_called_once() diff --git a/voice_assistant.py b/voice_assistant.py index de04ead..4b35eeb 100644 --- a/voice_assistant.py +++ b/voice_assistant.py @@ -12,6 +12,7 @@ End-to-end voice assistant pipeline using handler-base: 7. Synthesize speech with XTTS 8. Publish result to NATS "voice.response.{request_id}" """ + import base64 import logging from typing import Any, Optional @@ -34,18 +35,18 @@ logger = logging.getLogger("voice-assistant") class VoiceSettings(Settings): """Voice assistant specific settings.""" - + service_name: str = "voice-assistant" - + # RAG settings rag_top_k: int = 10 rag_rerank_top_k: int = 5 rag_collection: str = "documents" - + # Audio settings stt_language: Optional[str] = None # Auto-detect tts_language: str = "en" - + # Response settings include_transcription: bool = True include_sources: bool = False @@ -54,7 +55,7 @@ class VoiceSettings(Settings): class VoiceAssistant(Handler): """ Voice request handler with full STT -> RAG -> LLM -> TTS pipeline. - + Request format (msgpack): { "request_id": "uuid", @@ -62,7 +63,7 @@ class VoiceAssistant(Handler): "language": "optional language code", "collection": "optional collection name" } - + Response format: { "request_id": "uuid", @@ -71,7 +72,7 @@ class VoiceAssistant(Handler): "audio": "base64 encoded response audio" } """ - + def __init__(self): self.voice_settings = VoiceSettings() super().__init__( @@ -79,121 +80,116 @@ class VoiceAssistant(Handler): settings=self.voice_settings, queue_group="voice-assistants", ) - + async def setup(self) -> None: """Initialize service clients.""" logger.info("Initializing voice assistant clients...") - + self.stt = STTClient(self.voice_settings) self.embeddings = EmbeddingsClient(self.voice_settings) self.reranker = RerankerClient(self.voice_settings) self.llm = LLMClient(self.voice_settings) self.tts = TTSClient(self.voice_settings) self.milvus = MilvusClient(self.voice_settings) - + await self.milvus.connect(self.voice_settings.rag_collection) - + logger.info("Voice assistant clients initialized") - + async def teardown(self) -> None: """Clean up service clients.""" logger.info("Closing voice assistant clients...") - + await self.stt.close() await self.embeddings.close() await self.reranker.close() await self.llm.close() await self.tts.close() await self.milvus.close() - + logger.info("Voice assistant clients closed") - + async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]: """Handle incoming voice request.""" request_id = data.get("request_id", "unknown") audio_b64 = data.get("audio", "") language = data.get("language", self.voice_settings.stt_language) collection = data.get("collection", self.voice_settings.rag_collection) - + logger.info(f"Processing voice request {request_id}") - + with create_span("voice.process") as span: if span: span.set_attribute("request.id", request_id) - + # 1. Decode audio audio_bytes = base64.b64decode(audio_b64) - + # 2. Transcribe audio to text transcription = await self._transcribe(audio_bytes, language) query = transcription.get("text", "") - + if not query.strip(): logger.warning(f"Empty transcription for request {request_id}") return { "request_id": request_id, "error": "Could not transcribe audio", } - + logger.info(f"Transcribed: {query[:50]}...") - + # 3. Generate query embedding embedding = await self._get_embedding(query) - + # 4. Search Milvus for context documents = await self._search_context(embedding, collection) - + # 5. Rerank documents reranked = await self._rerank_documents(query, documents) - + # 6. Build context context = self._build_context(reranked) - + # 7. Generate LLM response response_text = await self._generate_response(query, context) - + # 8. Synthesize speech response_audio = await self._synthesize_speech(response_text) - + # Build response result = { "request_id": request_id, "response": response_text, "audio": response_audio, } - + if self.voice_settings.include_transcription: result["transcription"] = query - + if self.voice_settings.include_sources: result["sources"] = [ - {"text": d["document"][:200], "score": d["score"]} - for d in reranked[:3] + {"text": d["document"][:200], "score": d["score"]} for d in reranked[:3] ] - + logger.info(f"Completed voice request {request_id}") - + # Publish to response subject response_subject = f"voice.response.{request_id}" await self.nats.publish(response_subject, result) - + return result - - async def _transcribe( - self, audio: bytes, language: Optional[str] - ) -> dict: + + async def _transcribe(self, audio: bytes, language: Optional[str]) -> dict: """Transcribe audio to text.""" with create_span("voice.stt"): return await self.stt.transcribe(audio, language=language) - + async def _get_embedding(self, text: str) -> list[float]: """Generate embedding for query text.""" with create_span("voice.embedding"): return await self.embeddings.embed_single(text) - - async def _search_context( - self, embedding: list[float], collection: str - ) -> list[dict]: + + async def _search_context(self, embedding: list[float], collection: str) -> list[dict]: """Search Milvus for relevant documents.""" with create_span("voice.search"): return await self.milvus.search_with_texts( @@ -201,32 +197,28 @@ class VoiceAssistant(Handler): limit=self.voice_settings.rag_top_k, text_field="text", ) - - async def _rerank_documents( - self, query: str, documents: list[dict] - ) -> list[dict]: + + async def _rerank_documents(self, query: str, documents: list[dict]) -> list[dict]: """Rerank documents by relevance.""" with create_span("voice.rerank"): texts = [d.get("text", "") for d in documents] return await self.reranker.rerank( query, texts, top_k=self.voice_settings.rag_rerank_top_k ) - + def _build_context(self, documents: list[dict]) -> str: """Build context string from ranked documents.""" return "\n\n".join(d.get("document", "") for d in documents) - + async def _generate_response(self, query: str, context: str) -> str: """Generate LLM response.""" with create_span("voice.generate"): return await self.llm.generate(query, context=context) - + async def _synthesize_speech(self, text: str) -> str: """Synthesize speech and return base64.""" with create_span("voice.tts"): - audio_bytes = await self.tts.synthesize( - text, language=self.voice_settings.tts_language - ) + audio_bytes = await self.tts.synthesize(text, language=self.voice_settings.tts_language) return base64.b64encode(audio_bytes).decode()