Files
handler-base/handler_base/clients/reranker.py
Billy D. dbf1a93141
All checks were successful
CI / Lint (push) Successful in 52s
CI / Test (push) Successful in 1m1s
CI / Release (push) Successful in 5s
CI / Notify (push) Successful in 1s
fix: auto-fix ruff linting errors and remove unsupported upload-artifact
2026-02-02 08:34:00 -05:00

124 lines
3.6 KiB
Python

"""
Reranker service client (Infinity/BGE Reranker).
"""
import logging
from typing import Optional
import httpx
from handler_base.config import Settings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class RerankerClient:
    """
    Client for the reranker service (Infinity with BGE Reranker).

    Usage:
        client = RerankerClient()
        reranked = await client.rerank("query", ["doc1", "doc2"])
        await client.close()

    The client can also be used as an async context manager, which closes
    the underlying HTTP client on exit:
        async with RerankerClient() as client:
            reranked = await client.rerank("query", ["doc1", "doc2"])
    """

    def __init__(self, settings: Optional["Settings"] = None):
        """
        Create the client and its underlying HTTP connection pool.

        Args:
            settings: Settings object supplying ``reranker_url`` and
                ``http_timeout``. A default ``Settings()`` is created
                when omitted.
        """
        self.settings = settings or Settings()
        self._client = httpx.AsyncClient(
            base_url=self.settings.reranker_url,
            timeout=self.settings.http_timeout,
        )

    async def __aenter__(self) -> "RerankerClient":
        """Enter the async context; returns this client."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Leave the async context, closing the HTTP client."""
        await self.close()

    async def close(self) -> None:
        """Close the HTTP client."""
        await self._client.aclose()

    @staticmethod
    def _enrich_results(results: list[dict], documents: list[str]) -> list[dict]:
        """Attach the original document text to each raw result and sort by score."""
        total = len(documents)
        enriched = []
        for raw in results:
            idx = raw.get("index", 0)
            # The lower bound matters: a negative index would otherwise
            # silently wrap around and pick a document from the end.
            in_range = 0 <= idx < total
            enriched.append(
                {
                    "index": idx,
                    # The score may arrive as "relevance_score" or "score".
                    "score": raw.get("relevance_score", raw.get("score", 0)),
                    "document": documents[idx] if in_range else "",
                }
            )
        # Enforce the documented ordering instead of relying on the service.
        # list.sort is stable, so equal scores keep the service's order.
        enriched.sort(key=lambda item: item["score"], reverse=True)
        return enriched

    async def rerank(
        self,
        query: str,
        documents: list[str],
        top_k: Optional[int] = None,
    ) -> list[dict]:
        """
        Rerank documents based on relevance to query.

        Args:
            query: Query text
            documents: List of documents to rerank
            top_k: Number of top results to return (default: all).
                ``None`` or ``0`` leaves the limit unset.

        Returns:
            List of dicts with 'index', 'score', and 'document' keys,
            sorted by relevance score descending. 'document' is an empty
            string when the service reports an index outside ``documents``.
            An empty ``documents`` list yields ``[]`` without contacting
            the service.

        Raises:
            httpx.HTTPStatusError: If the service responds with an error
                status (raised by ``response.raise_for_status()``).
        """
        # Nothing to rank: skip the round trip entirely.
        if not documents:
            return []

        with create_span("reranker.rerank") as span:
            if span:
                span.set_attribute("reranker.num_documents", len(documents))
                if top_k:
                    span.set_attribute("reranker.top_k", top_k)

            payload: dict = {
                "query": query,
                "documents": documents,
            }
            if top_k:
                # The request field is named "top_n".
                payload["top_n"] = top_k

            response = await self._client.post("/rerank", json=payload)
            response.raise_for_status()
            results = response.json().get("results", [])
            return self._enrich_results(results, documents)

    async def rerank_with_metadata(
        self,
        query: str,
        documents: list[dict],
        text_key: str = "text",
        top_k: Optional[int] = None,
    ) -> list[dict]:
        """
        Rerank documents with metadata, preserving metadata in results.

        Args:
            query: Query text
            documents: List of dicts with text and metadata
            text_key: Key containing text in each document dict; documents
                missing this key are ranked as empty text
            top_k: Number of top results to return

        Returns:
            Reranked results as produced by ``rerank``, each with an added
            'metadata' dict holding every key of the source document except
            ``text_key``. 'metadata' is empty when the reported index does
            not refer to an input document.
        """
        texts = [d.get(text_key, "") for d in documents]
        reranked = await self.rerank(query, texts, top_k)

        # Merge back metadata
        total = len(documents)
        for item in reranked:
            idx = item["index"]
            if 0 <= idx < total:
                item["metadata"] = {
                    k: v for k, v in documents[idx].items() if k != text_key
                }
            else:
                # Keep the result shape uniform for out-of-range indices.
                item["metadata"] = {}
        return reranked

    async def health(self) -> bool:
        """
        Check if the reranker service is healthy.

        Returns:
            True when ``GET /health`` answers with status 200; False on any
            other status or when the request fails. Failures are logged
            rather than raised, so this is safe to call from probes.
        """
        try:
            response = await self._client.get("/health")
        except Exception as exc:
            # Best-effort probe: report unhealthy, but leave a trace of why.
            logger.warning("Reranker health check failed: %s", exc)
            return False
        return response.status_code == 200