feat: Add handler-base library for NATS AI/ML services
- Handler base class with graceful shutdown and signal handling - NATSClient with JetStream and msgpack serialization - Pydantic Settings for environment configuration - HealthServer for Kubernetes probes - OpenTelemetry telemetry setup - Service clients: STT, TTS, LLM, Embeddings, Reranker, Milvus
This commit is contained in:
91
handler_base/clients/embeddings.py
Normal file
91
handler_base/clients/embeddings.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Embeddings service client (Infinity/BGE).
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from handler_base.config import EmbeddingsSettings
|
||||
from handler_base.telemetry import create_span
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingsClient:
|
||||
"""
|
||||
Client for the embeddings service (Infinity with BGE models).
|
||||
|
||||
Usage:
|
||||
client = EmbeddingsClient()
|
||||
embeddings = await client.embed(["Hello world"])
|
||||
"""
|
||||
|
||||
def __init__(self, settings: Optional[EmbeddingsSettings] = None):
|
||||
self.settings = settings or EmbeddingsSettings()
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self.settings.embeddings_url,
|
||||
timeout=self.settings.http_timeout,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
await self._client.aclose()
|
||||
|
||||
async def embed(
|
||||
self,
|
||||
texts: list[str],
|
||||
model: Optional[str] = None,
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Generate embeddings for a list of texts.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
model: Model name (defaults to settings)
|
||||
|
||||
Returns:
|
||||
List of embedding vectors
|
||||
"""
|
||||
model = model or self.settings.embeddings_model
|
||||
|
||||
with create_span("embeddings.embed") as span:
|
||||
if span:
|
||||
span.set_attribute("embeddings.model", model)
|
||||
span.set_attribute("embeddings.batch_size", len(texts))
|
||||
|
||||
response = await self._client.post(
|
||||
"/embeddings",
|
||||
json={"input": texts, "model": model},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
embeddings = [d["embedding"] for d in result.get("data", [])]
|
||||
|
||||
if span:
|
||||
span.set_attribute("embeddings.dimensions", len(embeddings[0]) if embeddings else 0)
|
||||
|
||||
return embeddings
|
||||
|
||||
async def embed_single(self, text: str, model: Optional[str] = None) -> list[float]:
|
||||
"""
|
||||
Generate embedding for a single text.
|
||||
|
||||
Args:
|
||||
text: Text to embed
|
||||
model: Model name (defaults to settings)
|
||||
|
||||
Returns:
|
||||
Embedding vector
|
||||
"""
|
||||
embeddings = await self.embed([text], model)
|
||||
return embeddings[0] if embeddings else []
|
||||
|
||||
async def health(self) -> bool:
|
||||
"""Check if the embeddings service is healthy."""
|
||||
try:
|
||||
response = await self._client.get("/health")
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
Reference in New Issue
Block a user