fix: ruff formatting and allow-direct-references for handler-base dep
This commit is contained in:
@@ -11,6 +11,7 @@ Text-based chat pipeline using handler-base:
|
||||
6. Optionally synthesize speech with XTTS
|
||||
7. Publish result to NATS "ai.chat.response.{request_id}"
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
@@ -138,9 +139,7 @@ class ChatHandler(Handler):
|
||||
context = self._build_context(reranked)
|
||||
|
||||
# 5. Generate LLM response
|
||||
response_text = await self._generate_response(
|
||||
query, context, system_prompt
|
||||
)
|
||||
response_text = await self._generate_response(query, context, system_prompt)
|
||||
|
||||
# 6. Optionally synthesize speech
|
||||
audio_b64 = None
|
||||
@@ -155,8 +154,7 @@ class ChatHandler(Handler):
|
||||
|
||||
if self.chat_settings.include_sources:
|
||||
result["sources"] = [
|
||||
{"text": d["document"][:200], "score": d["score"]}
|
||||
for d in reranked[:3]
|
||||
{"text": d["document"][:200], "score": d["score"]} for d in reranked[:3]
|
||||
]
|
||||
|
||||
if audio_b64:
|
||||
@@ -175,9 +173,7 @@ class ChatHandler(Handler):
|
||||
with create_span("chat.embedding"):
|
||||
return await self.embeddings.embed_single(text)
|
||||
|
||||
async def _search_context(
|
||||
self, embedding: list[float], collection: str
|
||||
) -> list[dict]:
|
||||
async def _search_context(self, embedding: list[float], collection: str) -> list[dict]:
|
||||
"""Search Milvus for relevant documents."""
|
||||
with create_span("chat.search"):
|
||||
return await self.milvus.search_with_texts(
|
||||
@@ -187,9 +183,7 @@ class ChatHandler(Handler):
|
||||
metadata_fields=["source", "title"],
|
||||
)
|
||||
|
||||
async def _rerank_documents(
|
||||
self, query: str, documents: list[dict]
|
||||
) -> list[dict]:
|
||||
async def _rerank_documents(self, query: str, documents: list[dict]) -> list[dict]:
|
||||
"""Rerank documents by relevance to query."""
|
||||
with create_span("chat.rerank"):
|
||||
texts = [d.get("text", "") for d in documents]
|
||||
|
||||
@@ -22,6 +22,9 @@ dev = [
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["."]
|
||||
only-include = ["chat_handler.py"]
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""
|
||||
Pytest configuration and fixtures for chat-handler tests.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""
|
||||
Unit tests for ChatHandler.
|
||||
"""
|
||||
import base64
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from chat_handler import ChatHandler, ChatSettings
|
||||
|
||||
@@ -42,12 +42,13 @@ class TestChatHandler:
|
||||
@pytest.fixture
|
||||
def handler(self):
|
||||
"""Create handler with mocked clients."""
|
||||
with patch("chat_handler.EmbeddingsClient"), \
|
||||
patch("chat_handler.RerankerClient"), \
|
||||
patch("chat_handler.LLMClient"), \
|
||||
patch("chat_handler.TTSClient"), \
|
||||
patch("chat_handler.MilvusClient"):
|
||||
|
||||
with (
|
||||
patch("chat_handler.EmbeddingsClient"),
|
||||
patch("chat_handler.RerankerClient"),
|
||||
patch("chat_handler.LLMClient"),
|
||||
patch("chat_handler.TTSClient"),
|
||||
patch("chat_handler.MilvusClient"),
|
||||
):
|
||||
handler = ChatHandler()
|
||||
|
||||
# Setup mock clients
|
||||
@@ -63,12 +64,13 @@ class TestChatHandler:
|
||||
@pytest.fixture
|
||||
def handler_with_tts(self):
|
||||
"""Create handler with TTS enabled."""
|
||||
with patch("chat_handler.EmbeddingsClient"), \
|
||||
patch("chat_handler.RerankerClient"), \
|
||||
patch("chat_handler.LLMClient"), \
|
||||
patch("chat_handler.TTSClient"), \
|
||||
patch("chat_handler.MilvusClient"):
|
||||
|
||||
with (
|
||||
patch("chat_handler.EmbeddingsClient"),
|
||||
patch("chat_handler.RerankerClient"),
|
||||
patch("chat_handler.LLMClient"),
|
||||
patch("chat_handler.TTSClient"),
|
||||
patch("chat_handler.MilvusClient"),
|
||||
):
|
||||
handler = ChatHandler()
|
||||
handler.chat_settings.enable_tts = True
|
||||
|
||||
@@ -211,12 +213,13 @@ class TestChatHandler:
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_initializes_clients(self):
|
||||
"""Test that setup initializes all required clients."""
|
||||
with patch("chat_handler.EmbeddingsClient") as emb_cls, \
|
||||
patch("chat_handler.RerankerClient") as rer_cls, \
|
||||
patch("chat_handler.LLMClient") as llm_cls, \
|
||||
patch("chat_handler.TTSClient") as tts_cls, \
|
||||
patch("chat_handler.MilvusClient") as mil_cls:
|
||||
|
||||
with (
|
||||
patch("chat_handler.EmbeddingsClient") as emb_cls,
|
||||
patch("chat_handler.RerankerClient") as rer_cls,
|
||||
patch("chat_handler.LLMClient") as llm_cls,
|
||||
patch("chat_handler.TTSClient") as tts_cls,
|
||||
patch("chat_handler.MilvusClient") as mil_cls,
|
||||
):
|
||||
mil_cls.return_value.connect = AsyncMock()
|
||||
|
||||
handler = ChatHandler()
|
||||
|
||||
Reference in New Issue
Block a user