fix: ruff formatting and allow-direct-references for handler-base dep

This commit is contained in:
2026-02-02 08:44:34 -05:00
parent 09e8135f93
commit 9d26facfaa
4 changed files with 107 additions and 106 deletions

View File

@@ -11,6 +11,7 @@ Text-based chat pipeline using handler-base:
6. Optionally synthesize speech with XTTS
7. Publish result to NATS "ai.chat.response.{request_id}"
"""
import base64
import logging
from typing import Any, Optional
@@ -138,9 +139,7 @@ class ChatHandler(Handler):
context = self._build_context(reranked)
# 5. Generate LLM response
response_text = await self._generate_response(
query, context, system_prompt
)
response_text = await self._generate_response(query, context, system_prompt)
# 6. Optionally synthesize speech
audio_b64 = None
@@ -155,8 +154,7 @@ class ChatHandler(Handler):
if self.chat_settings.include_sources:
result["sources"] = [
{"text": d["document"][:200], "score": d["score"]}
for d in reranked[:3]
{"text": d["document"][:200], "score": d["score"]} for d in reranked[:3]
]
if audio_b64:
@@ -175,9 +173,7 @@ class ChatHandler(Handler):
with create_span("chat.embedding"):
return await self.embeddings.embed_single(text)
async def _search_context(
self, embedding: list[float], collection: str
) -> list[dict]:
async def _search_context(self, embedding: list[float], collection: str) -> list[dict]:
"""Search Milvus for relevant documents."""
with create_span("chat.search"):
return await self.milvus.search_with_texts(
@@ -187,9 +183,7 @@ class ChatHandler(Handler):
metadata_fields=["source", "title"],
)
async def _rerank_documents(
self, query: str, documents: list[dict]
) -> list[dict]:
async def _rerank_documents(self, query: str, documents: list[dict]) -> list[dict]:
"""Rerank documents by relevance to query."""
with create_span("chat.rerank"):
texts = [d.get("text", "") for d in documents]

View File

@@ -22,6 +22,9 @@ dev = [
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["."]
only-include = ["chat_handler.py"]

View File

@@ -1,9 +1,10 @@
"""
Pytest configuration and fixtures for chat-handler tests.
"""
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import MagicMock
import pytest

View File

@@ -1,9 +1,9 @@
"""
Unit tests for ChatHandler.
"""
import base64
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch
from chat_handler import ChatHandler, ChatSettings
@@ -42,12 +42,13 @@ class TestChatHandler:
@pytest.fixture
def handler(self):
"""Create handler with mocked clients."""
with patch("chat_handler.EmbeddingsClient"), \
patch("chat_handler.RerankerClient"), \
patch("chat_handler.LLMClient"), \
patch("chat_handler.TTSClient"), \
patch("chat_handler.MilvusClient"):
with (
patch("chat_handler.EmbeddingsClient"),
patch("chat_handler.RerankerClient"),
patch("chat_handler.LLMClient"),
patch("chat_handler.TTSClient"),
patch("chat_handler.MilvusClient"),
):
handler = ChatHandler()
# Setup mock clients
@@ -63,12 +64,13 @@ class TestChatHandler:
@pytest.fixture
def handler_with_tts(self):
"""Create handler with TTS enabled."""
with patch("chat_handler.EmbeddingsClient"), \
patch("chat_handler.RerankerClient"), \
patch("chat_handler.LLMClient"), \
patch("chat_handler.TTSClient"), \
patch("chat_handler.MilvusClient"):
with (
patch("chat_handler.EmbeddingsClient"),
patch("chat_handler.RerankerClient"),
patch("chat_handler.LLMClient"),
patch("chat_handler.TTSClient"),
patch("chat_handler.MilvusClient"),
):
handler = ChatHandler()
handler.chat_settings.enable_tts = True
@@ -211,12 +213,13 @@ class TestChatHandler:
@pytest.mark.asyncio
async def test_setup_initializes_clients(self):
"""Test that setup initializes all required clients."""
with patch("chat_handler.EmbeddingsClient") as emb_cls, \
patch("chat_handler.RerankerClient") as rer_cls, \
patch("chat_handler.LLMClient") as llm_cls, \
patch("chat_handler.TTSClient") as tts_cls, \
patch("chat_handler.MilvusClient") as mil_cls:
with (
patch("chat_handler.EmbeddingsClient") as emb_cls,
patch("chat_handler.RerankerClient") as rer_cls,
patch("chat_handler.LLMClient") as llm_cls,
patch("chat_handler.TTSClient") as tts_cls,
patch("chat_handler.MilvusClient") as mil_cls,
):
mil_cls.return_value.connect = AsyncMock()
handler = ChatHandler()