- Add py.typed marker for PEP 561 type hint support - Add ray_utils module for Ray handle detection and caching - Update all clients (Embeddings, LLM, TTS, STT, Reranker) to use Ray handles when running inside Ray cluster for faster internal calls - Add .pre-commit-config.yaml with ruff and standard hooks - Add pre-commit and ray[serve] to optional dependencies - Bump ruff version to 0.4.0
169 lines
5.4 KiB
Python
169 lines
5.4 KiB
Python
"""
|
|
Reranker service client (Infinity/BGE Reranker).
|
|
|
|
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Any, Optional
|
|
|
|
import httpx
|
|
|
|
from handler_base.config import Settings
|
|
from handler_base.ray_utils import get_ray_handle
|
|
from handler_base.telemetry import create_span
|
|
|
|
# Module-level logger; RerankerClient uses it to warn when a Ray handle call
# fails and the request falls back to HTTP.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RerankerClient:
    """
    Client for the reranker service (Infinity with BGE Reranker).

    When running inside Ray, automatically uses Ray handles for faster
    internal communication. Falls back to HTTP for external calls.

    Usage:
        client = RerankerClient()
        reranked = await client.rerank("query", ["doc1", "doc2"])
        await client.close()
    """

    # Ray Serve deployment configuration
    RAY_DEPLOYMENT_NAME = "RerankerDeployment"
    RAY_APP_NAME = "reranker"

    def __init__(self, settings: Optional["Settings"] = None):
        """
        Create the client.

        Args:
            settings: Configuration object; a default ``Settings()`` is built
                when omitted. ``reranker_url`` and ``http_timeout`` are read
                from it to configure the underlying HTTP client.
        """
        self.settings = settings or Settings()
        self._client = httpx.AsyncClient(
            base_url=self.settings.reranker_url,
            timeout=self.settings.http_timeout,
        )
        # Ray handle lookup is deferred until first use and then cached.
        self._ray_handle: Optional[Any] = None
        self._ray_checked = False

    def _get_ray_handle(self) -> Optional[Any]:
        """
        Return the Ray handle for the reranker deployment, looking it up once.

        The result of the first lookup is cached on the instance, including a
        falsy result, so later calls never repeat the lookup.
        """
        if not self._ray_checked:
            self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
            self._ray_checked = True
        return self._ray_handle

    @staticmethod
    def _enrich_results(results: list[dict], documents: list[str]) -> list[dict]:
        """Normalise raw reranker results and attach each original document text."""
        enriched = []
        for item in results:
            idx = item.get("index", 0)
            # Reject negative indices as well: Python would otherwise silently
            # count from the end of the list and attach the wrong document.
            in_range = 0 <= idx < len(documents)
            enriched.append(
                {
                    "index": idx,
                    # Accept either score key; "relevance_score" wins when both exist.
                    "score": item.get("relevance_score", item.get("score", 0)),
                    "document": documents[idx] if in_range else "",
                }
            )
        return enriched

    async def close(self) -> None:
        """Close the HTTP client."""
        await self._client.aclose()

    async def rerank(
        self,
        query: str,
        documents: list[str],
        top_k: Optional[int] = None,
    ) -> list[dict]:
        """
        Rerank documents based on relevance to query.

        The Ray handle is tried first when one is available; if that call (or
        processing its result) raises, a warning is logged and the request is
        retried over HTTP.

        Args:
            query: Query text
            documents: List of documents to rerank
            top_k: Number of top results to return (default: all). A falsy
                value (``None`` or ``0``) omits ``top_n`` from the HTTP
                request; the Ray path receives ``top_k`` unchanged.

        Returns:
            List of dicts with 'index', 'score', and 'document' keys, in the
            order returned by the service (this client does not sort them).
            An empty ``documents`` list yields ``[]`` without calling the
            service.

        Raises:
            httpx.HTTPStatusError: If the HTTP path returns an error status.
        """
        with create_span("reranker.rerank") as span:
            if span:
                span.set_attribute("reranker.num_documents", len(documents))
                if top_k:
                    span.set_attribute("reranker.top_k", top_k)

            # Nothing to rank: skip the round trip entirely.
            if not documents:
                return []

            # Try Ray handle first (faster internal path)
            handle = self._get_ray_handle()
            if handle:
                try:
                    if span:
                        span.set_attribute("reranker.transport", "ray")
                    results = await handle.rerank.remote(query, documents, top_k)
                    # Enrichment stays inside the try so a malformed Ray
                    # response also triggers the HTTP fallback.
                    return self._enrich_results(results, documents)
                except Exception as e:
                    # Deliberate best-effort: any Ray failure degrades to HTTP.
                    logger.warning("Ray handle failed, falling back to HTTP: %s", e)

            # HTTP fallback
            if span:
                span.set_attribute("reranker.transport", "http")

            payload = {
                "query": query,
                "documents": documents,
            }
            if top_k:
                payload["top_n"] = top_k

            response = await self._client.post("/rerank", json=payload)
            response.raise_for_status()

            results = response.json().get("results", [])
            return self._enrich_results(results, documents)

    async def rerank_with_metadata(
        self,
        query: str,
        documents: list[dict],
        text_key: str = "text",
        top_k: Optional[int] = None,
    ) -> list[dict]:
        """
        Rerank documents with metadata, preserving metadata in results.

        Args:
            query: Query text
            documents: List of dicts with text and metadata
            text_key: Key containing text in each document dict; documents
                missing this key are ranked as an empty string
            top_k: Number of top results to return

        Returns:
            The results of ``rerank``, each with an added 'metadata' dict
            holding every key of the source document except ``text_key``.
            Results whose index is outside ``documents`` get no 'metadata' key.
        """
        texts = [d.get(text_key, "") for d in documents]
        reranked = await self.rerank(query, texts, top_k)

        # Merge back metadata
        for r in reranked:
            idx = r["index"]
            # Guard both ends so a negative index cannot pick up the metadata
            # of an unrelated document via wrap-around indexing.
            if 0 <= idx < len(documents):
                r["metadata"] = {k: v for k, v in documents[idx].items() if k != text_key}

        return reranked

    async def health(self) -> bool:
        """
        Check if the reranker service is healthy.

        Issues an HTTP GET to ``/health`` (the Ray handle is not consulted).

        Returns:
            True only when the response status is 200; False for any other
            status or if the request raises.
        """
        try:
            response = await self._client.get("/health")
            return response.status_code == 200
        except Exception:
            # A health probe must never raise; any failure means "unhealthy".
            return False
|