fix: ruff formatting, allow-direct-references, and noqa for Kubeflow pipeline vars
All checks were successful
CI / Lint (push) Successful in 51s
CI / Test (push) Successful in 55s
CI / Release (push) Successful in 7s
CI / Notify (push) Successful in 2s

This commit is contained in:
2026-02-02 08:44:14 -05:00
parent 58465b77d8
commit 0462412353
5 changed files with 208 additions and 237 deletions

View File

@@ -14,13 +14,9 @@ from kfp import dsl
from kfp import compiler from kfp import compiler
@dsl.component( @dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"])
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def transcribe_audio( def transcribe_audio(
audio_b64: str, audio_b64: str, whisper_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local"
whisper_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local"
) -> str: ) -> str:
"""Transcribe audio using Whisper STT service.""" """Transcribe audio using Whisper STT service."""
import base64 import base64
@@ -32,43 +28,35 @@ def transcribe_audio(
response = client.post( response = client.post(
f"{whisper_url}/v1/audio/transcriptions", f"{whisper_url}/v1/audio/transcriptions",
files={"file": ("audio.wav", audio_bytes, "audio/wav")}, files={"file": ("audio.wav", audio_bytes, "audio/wav")},
data={"model": "whisper-large-v3", "language": "en"} data={"model": "whisper-large-v3", "language": "en"},
) )
result = response.json() result = response.json()
return result.get("text", "") return result.get("text", "")
@dsl.component( @dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"])
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def generate_embeddings( def generate_embeddings(
text: str, text: str, embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local"
embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local"
) -> list: ) -> list:
"""Generate embeddings for RAG retrieval.""" """Generate embeddings for RAG retrieval."""
import httpx import httpx
with httpx.Client(timeout=60.0) as client: with httpx.Client(timeout=60.0) as client:
response = client.post( response = client.post(
f"{embeddings_url}/embeddings", f"{embeddings_url}/embeddings", json={"input": text, "model": "bge-small-en-v1.5"}
json={"input": text, "model": "bge-small-en-v1.5"}
) )
result = response.json() result = response.json()
return result["data"][0]["embedding"] return result["data"][0]["embedding"]
@dsl.component( @dsl.component(base_image="python:3.13-slim", packages_to_install=["pymilvus"])
base_image="python:3.13-slim",
packages_to_install=["pymilvus"]
)
def retrieve_context( def retrieve_context(
embedding: list, embedding: list,
milvus_host: str = "milvus.ai-ml.svc.cluster.local", milvus_host: str = "milvus.ai-ml.svc.cluster.local",
collection_name: str = "knowledge_base", collection_name: str = "knowledge_base",
top_k: int = 5 top_k: int = 5,
) -> list: ) -> list:
"""Retrieve relevant documents from Milvus vector database.""" """Retrieve relevant documents from Milvus vector database."""
from pymilvus import connections, Collection, utility from pymilvus import connections, Collection, utility
@@ -86,30 +74,29 @@ def retrieve_context(
anns_field="embedding", anns_field="embedding",
param={"metric_type": "COSINE", "params": {"nprobe": 10}}, param={"metric_type": "COSINE", "params": {"nprobe": 10}},
limit=top_k, limit=top_k,
output_fields=["text", "source"] output_fields=["text", "source"],
) )
documents = [] documents = []
for hits in results: for hits in results:
for hit in hits: for hit in hits:
documents.append({ documents.append(
{
"text": hit.entity.get("text"), "text": hit.entity.get("text"),
"source": hit.entity.get("source"), "source": hit.entity.get("source"),
"score": hit.distance "score": hit.distance,
}) }
)
return documents return documents
@dsl.component( @dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"])
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def rerank_documents( def rerank_documents(
query: str, query: str,
documents: list, documents: list,
reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local", reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local",
top_k: int = 3 top_k: int = 3,
) -> list: ) -> list:
"""Rerank documents using BGE reranker.""" """Rerank documents using BGE reranker."""
import httpx import httpx
@@ -123,30 +110,25 @@ def rerank_documents(
json={ json={
"query": query, "query": query,
"documents": [doc["text"] for doc in documents], "documents": [doc["text"] for doc in documents],
"model": "bge-reranker-v2-m3" "model": "bge-reranker-v2-m3",
} },
) )
result = response.json() result = response.json()
# Sort by rerank score # Sort by rerank score
reranked = sorted( reranked = sorted(
zip(documents, result.get("scores", [0] * len(documents))), zip(documents, result.get("scores", [0] * len(documents))), key=lambda x: x[1], reverse=True
key=lambda x: x[1],
reverse=True
)[:top_k] )[:top_k]
return [doc for doc, score in reranked] return [doc for doc, score in reranked]
@dsl.component( @dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"])
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def generate_response( def generate_response(
query: str, query: str,
context: list, context: list,
vllm_url: str = "http://llm-draft.ai-ml.svc.cluster.local:8000", vllm_url: str = "http://llm-draft.ai-ml.svc.cluster.local:8000",
model: str = "mistralai/Mistral-7B-Instruct-v0.3" model: str = "mistralai/Mistral-7B-Instruct-v0.3",
) -> str: ) -> str:
"""Generate response using vLLM.""" """Generate response using vLLM."""
import httpx import httpx
@@ -164,31 +146,22 @@ Keep responses concise and natural for speech synthesis."""
messages = [ messages = [
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": user_content} {"role": "user", "content": user_content},
] ]
with httpx.Client(timeout=180.0) as client: with httpx.Client(timeout=180.0) as client:
response = client.post( response = client.post(
f"{vllm_url}/v1/chat/completions", f"{vllm_url}/v1/chat/completions",
json={ json={"model": model, "messages": messages, "max_tokens": 512, "temperature": 0.7},
"model": model,
"messages": messages,
"max_tokens": 512,
"temperature": 0.7
}
) )
result = response.json() result = response.json()
return result["choices"][0]["message"]["content"] return result["choices"][0]["message"]["content"]
@dsl.component( @dsl.component(base_image="python:3.13-slim", packages_to_install=["httpx"])
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def synthesize_speech( def synthesize_speech(
text: str, text: str, tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local"
tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local"
) -> str: ) -> str:
"""Convert text to speech using TTS service.""" """Convert text to speech using TTS service."""
import base64 import base64
@@ -197,11 +170,7 @@ def synthesize_speech(
with httpx.Client(timeout=120.0) as client: with httpx.Client(timeout=120.0) as client:
response = client.post( response = client.post(
f"{tts_url}/v1/audio/speech", f"{tts_url}/v1/audio/speech",
json={ json={"input": text, "voice": "en_US-lessac-high", "response_format": "wav"},
"input": text,
"voice": "en_US-lessac-high",
"response_format": "wav"
}
) )
audio_b64 = base64.b64encode(response.content).decode("utf-8") audio_b64 = base64.b64encode(response.content).decode("utf-8")
@@ -210,12 +179,9 @@ def synthesize_speech(
@dsl.pipeline( @dsl.pipeline(
name="voice-assistant-rag-pipeline", name="voice-assistant-rag-pipeline",
description="End-to-end voice assistant with RAG: STT -> Embeddings -> Milvus -> Rerank -> LLM -> TTS" description="End-to-end voice assistant with RAG: STT -> Embeddings -> Milvus -> Rerank -> LLM -> TTS",
) )
def voice_assistant_pipeline( def voice_assistant_pipeline(audio_b64: str, collection_name: str = "knowledge_base"):
audio_b64: str,
collection_name: str = "knowledge_base"
):
""" """
Voice Assistant Pipeline with RAG Voice Assistant Pipeline with RAG
@@ -233,44 +199,28 @@ def voice_assistant_pipeline(
embed_task.set_caching_options(enable_caching=True) embed_task.set_caching_options(enable_caching=True)
# Step 3: Retrieve context from Milvus # Step 3: Retrieve context from Milvus
retrieve_task = retrieve_context( retrieve_task = retrieve_context(embedding=embed_task.output, collection_name=collection_name)
embedding=embed_task.output,
collection_name=collection_name
)
# Step 4: Rerank documents # Step 4: Rerank documents
rerank_task = rerank_documents( rerank_task = rerank_documents(query=transcribe_task.output, documents=retrieve_task.output)
query=transcribe_task.output,
documents=retrieve_task.output
)
# Step 5: Generate response with context # Step 5: Generate response with context
llm_task = generate_response( llm_task = generate_response(query=transcribe_task.output, context=rerank_task.output)
query=transcribe_task.output,
context=rerank_task.output
)
# Step 6: Synthesize speech # Step 6: Synthesize speech
tts_task = synthesize_speech(text=llm_task.output) tts_task = synthesize_speech(text=llm_task.output) # noqa: F841
@dsl.pipeline( @dsl.pipeline(name="text-to-speech-pipeline", description="Simple text to speech pipeline")
name="text-to-speech-pipeline",
description="Simple text to speech pipeline"
)
def text_to_speech_pipeline(text: str): def text_to_speech_pipeline(text: str):
"""Simple TTS pipeline for testing.""" """Simple TTS pipeline for testing."""
tts_task = synthesize_speech(text=text) tts_task = synthesize_speech(text=text) # noqa: F841
@dsl.pipeline( @dsl.pipeline(
name="rag-query-pipeline", name="rag-query-pipeline", description="RAG query pipeline: Embed -> Retrieve -> Rerank -> LLM"
description="RAG query pipeline: Embed -> Retrieve -> Rerank -> LLM"
) )
def rag_query_pipeline( def rag_query_pipeline(query: str, collection_name: str = "knowledge_base"):
query: str,
collection_name: str = "knowledge_base"
):
""" """
RAG Query Pipeline (text input, no voice) RAG Query Pipeline (text input, no voice)
@@ -282,21 +232,14 @@ def rag_query_pipeline(
embed_task = generate_embeddings(text=query) embed_task = generate_embeddings(text=query)
# Retrieve from Milvus # Retrieve from Milvus
retrieve_task = retrieve_context( retrieve_task = retrieve_context(embedding=embed_task.output, collection_name=collection_name)
embedding=embed_task.output,
collection_name=collection_name
)
# Rerank # Rerank
rerank_task = rerank_documents( rerank_task = rerank_documents(query=query, documents=retrieve_task.output)
query=query,
documents=retrieve_task.output
)
# Generate response # Generate response
llm_task = generate_response( llm_task = generate_response( # noqa: F841
query=query, query=query, context=rerank_task.output
context=rerank_task.output
) )

View File

@@ -22,6 +22,9 @@ dev = [
requires = ["hatchling"] requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel] [tool.hatch.build.targets.wheel]
packages = ["."] packages = ["."]
only-include = ["voice_assistant.py"] only-include = ["voice_assistant.py"]

View File

@@ -1,11 +1,11 @@
""" """
Pytest configuration and fixtures for voice-assistant tests. Pytest configuration and fixtures for voice-assistant tests.
""" """
import asyncio import asyncio
import base64 import base64
import os import os
from typing import AsyncGenerator from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@@ -29,21 +29,54 @@ def event_loop():
def sample_audio_b64(): def sample_audio_b64():
"""Sample base64 encoded audio for testing.""" """Sample base64 encoded audio for testing."""
# 16-bit PCM silence (44 bytes header + 1000 samples) # 16-bit PCM silence (44 bytes header + 1000 samples)
wav_header = bytes([ wav_header = bytes(
0x52, 0x49, 0x46, 0x46, # "RIFF" [
0x24, 0x08, 0x00, 0x00, # File size 0x52,
0x57, 0x41, 0x56, 0x45, # "WAVE" 0x49,
0x66, 0x6D, 0x74, 0x20, # "fmt " 0x46,
0x10, 0x00, 0x00, 0x00, # Chunk size 0x46, # "RIFF"
0x01, 0x00, # PCM format 0x24,
0x01, 0x00, # Mono 0x08,
0x80, 0x3E, 0x00, 0x00, # Sample rate (16000) 0x00,
0x00, 0x7D, 0x00, 0x00, # Byte rate 0x00, # File size
0x02, 0x00, # Block align 0x57,
0x10, 0x00, # Bits per sample 0x41,
0x64, 0x61, 0x74, 0x61, # "data" 0x56,
0x00, 0x08, 0x00, 0x00, # Data size 0x45, # "WAVE"
]) 0x66,
0x6D,
0x74,
0x20, # "fmt "
0x10,
0x00,
0x00,
0x00, # Chunk size
0x01,
0x00, # PCM format
0x01,
0x00, # Mono
0x80,
0x3E,
0x00,
0x00, # Sample rate (16000)
0x00,
0x7D,
0x00,
0x00, # Byte rate
0x02,
0x00, # Block align
0x10,
0x00, # Bits per sample
0x64,
0x61,
0x74,
0x61, # "data"
0x00,
0x08,
0x00,
0x00, # Data size
]
)
silence = bytes([0x00] * 2048) silence = bytes([0x00] * 2048)
return base64.b64encode(wav_header + silence).decode() return base64.b64encode(wav_header + silence).decode()
@@ -96,13 +129,14 @@ def mock_voice_request(sample_audio_b64):
@pytest.fixture @pytest.fixture
def mock_clients(): def mock_clients():
"""Mock all service clients.""" """Mock all service clients."""
with patch("voice_assistant.STTClient") as stt, \ with (
patch("voice_assistant.EmbeddingsClient") as embeddings, \ patch("voice_assistant.STTClient") as stt,
patch("voice_assistant.RerankerClient") as reranker, \ patch("voice_assistant.EmbeddingsClient") as embeddings,
patch("voice_assistant.LLMClient") as llm, \ patch("voice_assistant.RerankerClient") as reranker,
patch("voice_assistant.TTSClient") as tts, \ patch("voice_assistant.LLMClient") as llm,
patch("voice_assistant.MilvusClient") as milvus: patch("voice_assistant.TTSClient") as tts,
patch("voice_assistant.MilvusClient") as milvus,
):
yield { yield {
"stt": stt, "stt": stt,
"embeddings": embeddings, "embeddings": embeddings,

View File

@@ -1,9 +1,9 @@
""" """
Unit tests for VoiceAssistant handler. Unit tests for VoiceAssistant handler.
""" """
import base64
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, patch
# Import after environment is set up in conftest # Import after environment is set up in conftest
from voice_assistant import VoiceAssistant, VoiceSettings from voice_assistant import VoiceAssistant, VoiceSettings
@@ -31,10 +31,7 @@ class TestVoiceSettings:
monkeypatch.setenv("RAG_COLLECTION", "custom_collection") monkeypatch.setenv("RAG_COLLECTION", "custom_collection")
# Note: Would need to re-instantiate settings to pick up env vars # Note: Would need to re-instantiate settings to pick up env vars
settings = VoiceSettings( settings = VoiceSettings(rag_top_k=20, rag_collection="custom_collection")
rag_top_k=20,
rag_collection="custom_collection"
)
assert settings.rag_top_k == 20 assert settings.rag_top_k == 20
assert settings.rag_collection == "custom_collection" assert settings.rag_collection == "custom_collection"
@@ -46,13 +43,14 @@ class TestVoiceAssistant:
@pytest.fixture @pytest.fixture
def handler(self): def handler(self):
"""Create handler with mocked clients.""" """Create handler with mocked clients."""
with patch("voice_assistant.STTClient"), \ with (
patch("voice_assistant.EmbeddingsClient"), \ patch("voice_assistant.STTClient"),
patch("voice_assistant.RerankerClient"), \ patch("voice_assistant.EmbeddingsClient"),
patch("voice_assistant.LLMClient"), \ patch("voice_assistant.RerankerClient"),
patch("voice_assistant.TTSClient"), \ patch("voice_assistant.LLMClient"),
patch("voice_assistant.MilvusClient"): patch("voice_assistant.TTSClient"),
patch("voice_assistant.MilvusClient"),
):
handler = VoiceAssistant() handler = VoiceAssistant()
# Setup mock clients # Setup mock clients
@@ -167,13 +165,14 @@ class TestVoiceAssistant:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_initializes_clients(self): async def test_setup_initializes_clients(self):
"""Test that setup initializes all clients.""" """Test that setup initializes all clients."""
with patch("voice_assistant.STTClient") as stt_cls, \ with (
patch("voice_assistant.EmbeddingsClient") as emb_cls, \ patch("voice_assistant.STTClient") as stt_cls,
patch("voice_assistant.RerankerClient") as rer_cls, \ patch("voice_assistant.EmbeddingsClient") as emb_cls,
patch("voice_assistant.LLMClient") as llm_cls, \ patch("voice_assistant.RerankerClient") as rer_cls,
patch("voice_assistant.TTSClient") as tts_cls, \ patch("voice_assistant.LLMClient") as llm_cls,
patch("voice_assistant.MilvusClient") as mil_cls: patch("voice_assistant.TTSClient") as tts_cls,
patch("voice_assistant.MilvusClient") as mil_cls,
):
mil_cls.return_value.connect = AsyncMock() mil_cls.return_value.connect = AsyncMock()
handler = VoiceAssistant() handler = VoiceAssistant()

View File

@@ -12,6 +12,7 @@ End-to-end voice assistant pipeline using handler-base:
7. Synthesize speech with XTTS 7. Synthesize speech with XTTS
8. Publish result to NATS "voice.response.{request_id}" 8. Publish result to NATS "voice.response.{request_id}"
""" """
import base64 import base64
import logging import logging
from typing import Any, Optional from typing import Any, Optional
@@ -167,8 +168,7 @@ class VoiceAssistant(Handler):
if self.voice_settings.include_sources: if self.voice_settings.include_sources:
result["sources"] = [ result["sources"] = [
{"text": d["document"][:200], "score": d["score"]} {"text": d["document"][:200], "score": d["score"]} for d in reranked[:3]
for d in reranked[:3]
] ]
logger.info(f"Completed voice request {request_id}") logger.info(f"Completed voice request {request_id}")
@@ -179,9 +179,7 @@ class VoiceAssistant(Handler):
return result return result
async def _transcribe( async def _transcribe(self, audio: bytes, language: Optional[str]) -> dict:
self, audio: bytes, language: Optional[str]
) -> dict:
"""Transcribe audio to text.""" """Transcribe audio to text."""
with create_span("voice.stt"): with create_span("voice.stt"):
return await self.stt.transcribe(audio, language=language) return await self.stt.transcribe(audio, language=language)
@@ -191,9 +189,7 @@ class VoiceAssistant(Handler):
with create_span("voice.embedding"): with create_span("voice.embedding"):
return await self.embeddings.embed_single(text) return await self.embeddings.embed_single(text)
async def _search_context( async def _search_context(self, embedding: list[float], collection: str) -> list[dict]:
self, embedding: list[float], collection: str
) -> list[dict]:
"""Search Milvus for relevant documents.""" """Search Milvus for relevant documents."""
with create_span("voice.search"): with create_span("voice.search"):
return await self.milvus.search_with_texts( return await self.milvus.search_with_texts(
@@ -202,9 +198,7 @@ class VoiceAssistant(Handler):
text_field="text", text_field="text",
) )
async def _rerank_documents( async def _rerank_documents(self, query: str, documents: list[dict]) -> list[dict]:
self, query: str, documents: list[dict]
) -> list[dict]:
"""Rerank documents by relevance.""" """Rerank documents by relevance."""
with create_span("voice.rerank"): with create_span("voice.rerank"):
texts = [d.get("text", "") for d in documents] texts = [d.get("text", "") for d in documents]
@@ -224,9 +218,7 @@ class VoiceAssistant(Handler):
async def _synthesize_speech(self, text: str) -> str: async def _synthesize_speech(self, text: str) -> str:
"""Synthesize speech and return base64.""" """Synthesize speech and return base64."""
with create_span("voice.tts"): with create_span("voice.tts"):
audio_bytes = await self.tts.synthesize( audio_bytes = await self.tts.synthesize(text, language=self.voice_settings.tts_language)
text, language=self.voice_settings.tts_language
)
return base64.b64encode(audio_bytes).decode() return base64.b64encode(audio_bytes).decode()