Files
handler-base/handler_base/clients/embeddings.py
Billy D. 408f31e56d
All checks were successful
CI / Test (push) Successful in 4m8s
CI / Lint (push) Successful in 4m19s
CI / Release (push) Successful in 58s
CI / Notify (push) Successful in 2s
feat: add py.typed, Ray handles for clients, and pre-commit config
- Add py.typed marker for PEP 561 type hint support
- Add ray_utils module for Ray handle detection and caching
- Update all clients (Embeddings, LLM, TTS, STT, Reranker) to use
  Ray handles when running inside Ray cluster for faster internal calls
- Add .pre-commit-config.yaml with ruff and standard hooks
- Add pre-commit and ray[serve] to optional dependencies
- Bump ruff version to 0.4.0
2026-02-02 09:08:43 -05:00

129 lines
4.0 KiB
Python

"""
Embeddings service client (Infinity/BGE).
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
"""
import logging
from typing import Any, Optional
import httpx
from handler_base.config import EmbeddingsSettings
from handler_base.ray_utils import get_ray_handle
from handler_base.telemetry import create_span
# Module-level logger; used to report Ray-handle failures before falling back to HTTP.
logger = logging.getLogger(__name__)
class EmbeddingsClient:
    """
    Client for the embeddings service (Infinity with BGE models).

    When running inside Ray, automatically uses Ray handles for faster
    internal communication. Falls back to HTTP for external calls, and also
    whenever a call through the Ray handle raises.

    Usage:
        client = EmbeddingsClient()
        embeddings = await client.embed(["Hello world"])
        await client.close()

        # or let the HTTP client be closed automatically:
        async with EmbeddingsClient() as client:
            embeddings = await client.embed(["Hello world"])
    """

    # Ray Serve deployment configuration
    RAY_DEPLOYMENT_NAME = "EmbeddingsDeployment"
    RAY_APP_NAME = "embeddings"

    def __init__(self, settings: Optional[EmbeddingsSettings] = None):
        """
        Create the client.

        Args:
            settings: Service settings; a fresh ``EmbeddingsSettings()`` is
                used when omitted.
        """
        self.settings = settings or EmbeddingsSettings()
        self._client = httpx.AsyncClient(
            base_url=self.settings.embeddings_url,
            timeout=self.settings.http_timeout,
        )
        self._ray_handle: Optional[Any] = None
        # The Ray lookup runs at most once; a "not found" (None) result is cached too.
        self._ray_checked = False

    async def __aenter__(self) -> "EmbeddingsClient":
        """Enter an ``async with`` block; returns the client itself."""
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Close the HTTP client when leaving an ``async with`` block."""
        await self.close()

    def _get_ray_handle(self) -> Optional[Any]:
        """Return the Ray Serve handle (or None), looking it up only on the first call."""
        if not self._ray_checked:
            self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
            self._ray_checked = True
        return self._ray_handle

    async def close(self) -> None:
        """Close the HTTP client."""
        await self._client.aclose()

    async def embed(
        self,
        texts: list[str],
        model: Optional[str] = None,
    ) -> list[list[float]]:
        """
        Generate embeddings for a list of texts.

        The Ray Serve handle is tried first when one is available. If the Ray
        call raises, a warning is logged and the same request is sent over HTTP.

        Args:
            texts: List of texts to embed. An empty list returns ``[]``
                without contacting the service.
            model: Model name (defaults to ``settings.embeddings_model``)

        Returns:
            List of embedding vectors, in the order returned by the service.

        Raises:
            httpx.HTTPError: If the HTTP fallback fails at the transport level
                or the service answers with a 4xx/5xx status.
        """
        # No texts means no vectors; skip the round trip instead of sending an empty batch.
        if not texts:
            return []

        model = model or self.settings.embeddings_model
        with create_span("embeddings.embed") as span:
            if span:
                span.set_attribute("embeddings.model", model)
                span.set_attribute("embeddings.batch_size", len(texts))

            # Try Ray handle first (faster internal path)
            handle = self._get_ray_handle()
            if handle is not None:
                if span:
                    span.set_attribute("embeddings.transport", "ray")
                try:
                    result = await handle.embed.remote(texts, model)
                except Exception as e:
                    # Best-effort fast path: any Ray failure degrades to HTTP.
                    logger.warning("Ray handle failed, falling back to HTTP: %s", e)
                else:
                    if span:
                        span.set_attribute(
                            "embeddings.dimensions", len(result[0]) if result else 0
                        )
                    return result

            # HTTP fallback
            if span:
                span.set_attribute("embeddings.transport", "http")
            response = await self._client.post(
                "/embeddings",
                json={"input": texts, "model": model},
            )
            response.raise_for_status()
            result = response.json()
            embeddings = [d["embedding"] for d in result.get("data", [])]
            if span:
                span.set_attribute(
                    "embeddings.dimensions", len(embeddings[0]) if embeddings else 0
                )
            return embeddings

    async def embed_single(self, text: str, model: Optional[str] = None) -> list[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text to embed
            model: Model name (defaults to settings)

        Returns:
            Embedding vector, or ``[]`` if the service returned no embeddings.
        """
        embeddings = await self.embed([text], model)
        return embeddings[0] if embeddings else []

    async def health(self) -> bool:
        """
        Check if the embeddings service is healthy.

        Only the HTTP endpoint is probed (``GET /health``); the Ray path is not
        checked. Returns True only for a 200 response.
        """
        try:
            response = await self._client.get("/health")
            return response.status_code == 200
        except Exception:
            # Deliberately best-effort: any failure to reach the service means unhealthy.
            return False