fix: ruff formatting and allow-direct-references for handler-base dep

This commit is contained in:
2026-02-02 08:44:34 -05:00
parent 09e8135f93
commit 9d26facfaa
4 changed files with 107 additions and 106 deletions

View File

@@ -11,6 +11,7 @@ Text-based chat pipeline using handler-base:
6. Optionally synthesize speech with XTTS 6. Optionally synthesize speech with XTTS
7. Publish result to NATS "ai.chat.response.{request_id}" 7. Publish result to NATS "ai.chat.response.{request_id}"
""" """
import base64 import base64
import logging import logging
from typing import Any, Optional from typing import Any, Optional
@@ -138,9 +139,7 @@ class ChatHandler(Handler):
context = self._build_context(reranked) context = self._build_context(reranked)
# 5. Generate LLM response # 5. Generate LLM response
response_text = await self._generate_response( response_text = await self._generate_response(query, context, system_prompt)
query, context, system_prompt
)
# 6. Optionally synthesize speech # 6. Optionally synthesize speech
audio_b64 = None audio_b64 = None
@@ -155,8 +154,7 @@ class ChatHandler(Handler):
if self.chat_settings.include_sources: if self.chat_settings.include_sources:
result["sources"] = [ result["sources"] = [
{"text": d["document"][:200], "score": d["score"]} {"text": d["document"][:200], "score": d["score"]} for d in reranked[:3]
for d in reranked[:3]
] ]
if audio_b64: if audio_b64:
@@ -175,9 +173,7 @@ class ChatHandler(Handler):
with create_span("chat.embedding"): with create_span("chat.embedding"):
return await self.embeddings.embed_single(text) return await self.embeddings.embed_single(text)
async def _search_context( async def _search_context(self, embedding: list[float], collection: str) -> list[dict]:
self, embedding: list[float], collection: str
) -> list[dict]:
"""Search Milvus for relevant documents.""" """Search Milvus for relevant documents."""
with create_span("chat.search"): with create_span("chat.search"):
return await self.milvus.search_with_texts( return await self.milvus.search_with_texts(
@@ -187,9 +183,7 @@ class ChatHandler(Handler):
metadata_fields=["source", "title"], metadata_fields=["source", "title"],
) )
async def _rerank_documents( async def _rerank_documents(self, query: str, documents: list[dict]) -> list[dict]:
self, query: str, documents: list[dict]
) -> list[dict]:
"""Rerank documents by relevance to query.""" """Rerank documents by relevance to query."""
with create_span("chat.rerank"): with create_span("chat.rerank"):
texts = [d.get("text", "") for d in documents] texts = [d.get("text", "") for d in documents]

View File

@@ -22,6 +22,9 @@ dev = [
requires = ["hatchling"] requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel] [tool.hatch.build.targets.wheel]
packages = ["."] packages = ["."]
only-include = ["chat_handler.py"] only-include = ["chat_handler.py"]

View File

@@ -1,9 +1,10 @@
""" """
Pytest configuration and fixtures for chat-handler tests. Pytest configuration and fixtures for chat-handler tests.
""" """
import asyncio import asyncio
import os import os
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import MagicMock
import pytest import pytest

View File

@@ -1,9 +1,9 @@
""" """
Unit tests for ChatHandler. Unit tests for ChatHandler.
""" """
import base64
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, patch
from chat_handler import ChatHandler, ChatSettings from chat_handler import ChatHandler, ChatSettings
@@ -42,12 +42,13 @@ class TestChatHandler:
@pytest.fixture @pytest.fixture
def handler(self): def handler(self):
"""Create handler with mocked clients.""" """Create handler with mocked clients."""
with patch("chat_handler.EmbeddingsClient"), \ with (
patch("chat_handler.RerankerClient"), \ patch("chat_handler.EmbeddingsClient"),
patch("chat_handler.LLMClient"), \ patch("chat_handler.RerankerClient"),
patch("chat_handler.TTSClient"), \ patch("chat_handler.LLMClient"),
patch("chat_handler.MilvusClient"): patch("chat_handler.TTSClient"),
patch("chat_handler.MilvusClient"),
):
handler = ChatHandler() handler = ChatHandler()
# Setup mock clients # Setup mock clients
@@ -63,12 +64,13 @@ class TestChatHandler:
@pytest.fixture @pytest.fixture
def handler_with_tts(self): def handler_with_tts(self):
"""Create handler with TTS enabled.""" """Create handler with TTS enabled."""
with patch("chat_handler.EmbeddingsClient"), \ with (
patch("chat_handler.RerankerClient"), \ patch("chat_handler.EmbeddingsClient"),
patch("chat_handler.LLMClient"), \ patch("chat_handler.RerankerClient"),
patch("chat_handler.TTSClient"), \ patch("chat_handler.LLMClient"),
patch("chat_handler.MilvusClient"): patch("chat_handler.TTSClient"),
patch("chat_handler.MilvusClient"),
):
handler = ChatHandler() handler = ChatHandler()
handler.chat_settings.enable_tts = True handler.chat_settings.enable_tts = True
@@ -211,12 +213,13 @@ class TestChatHandler:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_initializes_clients(self): async def test_setup_initializes_clients(self):
"""Test that setup initializes all required clients.""" """Test that setup initializes all required clients."""
with patch("chat_handler.EmbeddingsClient") as emb_cls, \ with (
patch("chat_handler.RerankerClient") as rer_cls, \ patch("chat_handler.EmbeddingsClient") as emb_cls,
patch("chat_handler.LLMClient") as llm_cls, \ patch("chat_handler.RerankerClient") as rer_cls,
patch("chat_handler.TTSClient") as tts_cls, \ patch("chat_handler.LLMClient") as llm_cls,
patch("chat_handler.MilvusClient") as mil_cls: patch("chat_handler.TTSClient") as tts_cls,
patch("chat_handler.MilvusClient") as mil_cls,
):
mil_cls.return_value.connect = AsyncMock() mil_cls.return_value.connect = AsyncMock()
handler = ChatHandler() handler = ChatHandler()