fix: ruff formatting and allow-direct-references for handler-base dep
This commit is contained in:
@@ -11,6 +11,7 @@ Text-based chat pipeline using handler-base:
|
|||||||
6. Optionally synthesize speech with XTTS
|
6. Optionally synthesize speech with XTTS
|
||||||
7. Publish result to NATS "ai.chat.response.{request_id}"
|
7. Publish result to NATS "ai.chat.response.{request_id}"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
@@ -138,9 +139,7 @@ class ChatHandler(Handler):
|
|||||||
context = self._build_context(reranked)
|
context = self._build_context(reranked)
|
||||||
|
|
||||||
# 5. Generate LLM response
|
# 5. Generate LLM response
|
||||||
response_text = await self._generate_response(
|
response_text = await self._generate_response(query, context, system_prompt)
|
||||||
query, context, system_prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
# 6. Optionally synthesize speech
|
# 6. Optionally synthesize speech
|
||||||
audio_b64 = None
|
audio_b64 = None
|
||||||
@@ -155,8 +154,7 @@ class ChatHandler(Handler):
|
|||||||
|
|
||||||
if self.chat_settings.include_sources:
|
if self.chat_settings.include_sources:
|
||||||
result["sources"] = [
|
result["sources"] = [
|
||||||
{"text": d["document"][:200], "score": d["score"]}
|
{"text": d["document"][:200], "score": d["score"]} for d in reranked[:3]
|
||||||
for d in reranked[:3]
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if audio_b64:
|
if audio_b64:
|
||||||
@@ -175,9 +173,7 @@ class ChatHandler(Handler):
|
|||||||
with create_span("chat.embedding"):
|
with create_span("chat.embedding"):
|
||||||
return await self.embeddings.embed_single(text)
|
return await self.embeddings.embed_single(text)
|
||||||
|
|
||||||
async def _search_context(
|
async def _search_context(self, embedding: list[float], collection: str) -> list[dict]:
|
||||||
self, embedding: list[float], collection: str
|
|
||||||
) -> list[dict]:
|
|
||||||
"""Search Milvus for relevant documents."""
|
"""Search Milvus for relevant documents."""
|
||||||
with create_span("chat.search"):
|
with create_span("chat.search"):
|
||||||
return await self.milvus.search_with_texts(
|
return await self.milvus.search_with_texts(
|
||||||
@@ -187,9 +183,7 @@ class ChatHandler(Handler):
|
|||||||
metadata_fields=["source", "title"],
|
metadata_fields=["source", "title"],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _rerank_documents(
|
async def _rerank_documents(self, query: str, documents: list[dict]) -> list[dict]:
|
||||||
self, query: str, documents: list[dict]
|
|
||||||
) -> list[dict]:
|
|
||||||
"""Rerank documents by relevance to query."""
|
"""Rerank documents by relevance to query."""
|
||||||
with create_span("chat.rerank"):
|
with create_span("chat.rerank"):
|
||||||
texts = [d.get("text", "") for d in documents]
|
texts = [d.get("text", "") for d in documents]
|
||||||
|
|||||||
@@ -22,6 +22,9 @@ dev = [
|
|||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
build-backend = "hatchling.build"
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.metadata]
|
||||||
|
allow-direct-references = true
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["."]
|
packages = ["."]
|
||||||
only-include = ["chat_handler.py"]
|
only-include = ["chat_handler.py"]
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
Pytest configuration and fixtures for chat-handler tests.
|
Pytest configuration and fixtures for chat-handler tests.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Unit tests for ChatHandler.
|
Unit tests for ChatHandler.
|
||||||
"""
|
"""
|
||||||
import base64
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
from chat_handler import ChatHandler, ChatSettings
|
from chat_handler import ChatHandler, ChatSettings
|
||||||
|
|
||||||
@@ -42,12 +42,13 @@ class TestChatHandler:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def handler(self):
|
def handler(self):
|
||||||
"""Create handler with mocked clients."""
|
"""Create handler with mocked clients."""
|
||||||
with patch("chat_handler.EmbeddingsClient"), \
|
with (
|
||||||
patch("chat_handler.RerankerClient"), \
|
patch("chat_handler.EmbeddingsClient"),
|
||||||
patch("chat_handler.LLMClient"), \
|
patch("chat_handler.RerankerClient"),
|
||||||
patch("chat_handler.TTSClient"), \
|
patch("chat_handler.LLMClient"),
|
||||||
patch("chat_handler.MilvusClient"):
|
patch("chat_handler.TTSClient"),
|
||||||
|
patch("chat_handler.MilvusClient"),
|
||||||
|
):
|
||||||
handler = ChatHandler()
|
handler = ChatHandler()
|
||||||
|
|
||||||
# Setup mock clients
|
# Setup mock clients
|
||||||
@@ -63,12 +64,13 @@ class TestChatHandler:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def handler_with_tts(self):
|
def handler_with_tts(self):
|
||||||
"""Create handler with TTS enabled."""
|
"""Create handler with TTS enabled."""
|
||||||
with patch("chat_handler.EmbeddingsClient"), \
|
with (
|
||||||
patch("chat_handler.RerankerClient"), \
|
patch("chat_handler.EmbeddingsClient"),
|
||||||
patch("chat_handler.LLMClient"), \
|
patch("chat_handler.RerankerClient"),
|
||||||
patch("chat_handler.TTSClient"), \
|
patch("chat_handler.LLMClient"),
|
||||||
patch("chat_handler.MilvusClient"):
|
patch("chat_handler.TTSClient"),
|
||||||
|
patch("chat_handler.MilvusClient"),
|
||||||
|
):
|
||||||
handler = ChatHandler()
|
handler = ChatHandler()
|
||||||
handler.chat_settings.enable_tts = True
|
handler.chat_settings.enable_tts = True
|
||||||
|
|
||||||
@@ -211,12 +213,13 @@ class TestChatHandler:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup_initializes_clients(self):
|
async def test_setup_initializes_clients(self):
|
||||||
"""Test that setup initializes all required clients."""
|
"""Test that setup initializes all required clients."""
|
||||||
with patch("chat_handler.EmbeddingsClient") as emb_cls, \
|
with (
|
||||||
patch("chat_handler.RerankerClient") as rer_cls, \
|
patch("chat_handler.EmbeddingsClient") as emb_cls,
|
||||||
patch("chat_handler.LLMClient") as llm_cls, \
|
patch("chat_handler.RerankerClient") as rer_cls,
|
||||||
patch("chat_handler.TTSClient") as tts_cls, \
|
patch("chat_handler.LLMClient") as llm_cls,
|
||||||
patch("chat_handler.MilvusClient") as mil_cls:
|
patch("chat_handler.TTSClient") as tts_cls,
|
||||||
|
patch("chat_handler.MilvusClient") as mil_cls,
|
||||||
|
):
|
||||||
mil_cls.return_value.connect = AsyncMock()
|
mil_cls.return_value.connect = AsyncMock()
|
||||||
|
|
||||||
handler = ChatHandler()
|
handler = ChatHandler()
|
||||||
|
|||||||
Reference in New Issue
Block a user