- Handler base class with graceful shutdown and signal handling
- NATSClient with JetStream and msgpack serialization
- Pydantic Settings for environment configuration
- HealthServer for Kubernetes probes
- OpenTelemetry telemetry setup
- Service clients: STT, TTS, LLM, Embeddings, Reranker, Milvus
183 lines
5.8 KiB
Python
183 lines
5.8 KiB
Python
"""
|
|
Milvus vector database client.
|
|
"""
|
|
import logging
|
|
from typing import Optional, Any
|
|
|
|
from pymilvus import connections, Collection, utility
|
|
|
|
from handler_base.config import Settings
|
|
from handler_base.telemetry import create_span
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MilvusClient:
    """
    Client for the Milvus vector database.

    Holds one pymilvus connection (alias "default") and at most one loaded
    collection. The coroutine methods call pymilvus without awaiting, so each
    call runs synchronously on the caller's event loop.

    Usage:
        client = MilvusClient()
        await client.connect()
        results = await client.search(embedding, limit=10)
    """

    # pymilvus connection alias used by every call in this class.
    _ALIAS = "default"
    # Name of the vector field searched against and populated by insert().
    _VECTOR_FIELD = "embedding"
    # Name of the primary-key field that insert() never supplies a column for.
    _PRIMARY_FIELD = "id"

    def __init__(self, settings: Optional[Settings] = None):
        """
        Create an unconnected client.

        Args:
            settings: Configuration providing milvus_host, milvus_port and
                milvus_collection. A new Settings() is built when omitted.
        """
        self.settings = settings or Settings()
        self._connected = False
        self._collection: Optional[Collection] = None

    async def connect(self, collection_name: Optional[str] = None) -> None:
        """
        Connect to Milvus and load the collection if it exists.

        The client is marked connected even when the collection is missing;
        in that case only a warning is logged and search()/insert() will
        raise RuntimeError until connect() succeeds against an existing
        collection.

        Args:
            collection_name: Collection to use (defaults to settings)
        """
        collection_name = collection_name or self.settings.milvus_collection

        connections.connect(
            alias=self._ALIAS,
            host=self.settings.milvus_host,
            port=self.settings.milvus_port,
        )

        if utility.has_collection(collection_name):
            self._collection = Collection(collection_name)
            self._collection.load()
            logger.info(f"Connected to Milvus collection: {collection_name}")
        else:
            # Forget any handle from a previous connect() so later calls fail
            # fast instead of silently targeting a different collection.
            self._collection = None
            logger.warning(f"Collection {collection_name} does not exist")

        self._connected = True

    async def close(self) -> None:
        """
        Release the loaded collection (if any) and close the connection.

        State is reset and the connection is dropped even when releasing the
        collection fails; the release error still propagates to the caller.
        """
        loaded = self._collection
        # Reset state first so a failure below can never leave the client
        # claiming to be connected or holding a released collection.
        self._collection = None
        self._connected = False
        try:
            if loaded is not None:
                loaded.release()
        finally:
            connections.disconnect(self._ALIAS)
        logger.info("Disconnected from Milvus")

    async def search(
        self,
        embedding: list[float],
        limit: int = 10,
        output_fields: Optional[list[str]] = None,
        filter_expr: Optional[str] = None,
    ) -> list[dict]:
        """
        Search for similar vectors.

        Args:
            embedding: Query embedding vector
            limit: Maximum number of results
            output_fields: Extra entity fields to copy into each result. When
                omitted, results carry only 'id', 'distance' and 'score'.
            filter_expr: Optional filter expression

        Returns:
            List of dicts with 'id', 'distance', 'score' and every requested
            field that is present on the hit's entity.

        Raises:
            RuntimeError: If no collection is loaded.
        """
        if self._collection is None:
            raise RuntimeError("Not connected to collection")

        with create_span("milvus.search") as span:
            if span:
                span.set_attribute("milvus.collection", self._collection.name)
                span.set_attribute("milvus.limit", limit)

            search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}

            results = self._collection.search(
                data=[embedding],
                anns_field=self._VECTOR_FIELD,
                param=search_params,
                limit=limit,
                output_fields=output_fields,
                expr=filter_expr,
            )

            # Only one query vector is sent, so only results[0] is relevant.
            hits = []
            for hit in results[0]:
                item = {
                    "id": hit.id,
                    "distance": hit.distance,
                    # NOTE(review): with metric_type COSINE Milvus appears to
                    # report similarity (higher = closer) in `distance`, in
                    # which case `1 - distance` inverts the ranking. Left
                    # unchanged because callers may rely on it - confirm.
                    "score": 1 - hit.distance,
                }
                if output_fields:
                    for field in output_fields:
                        if hasattr(hit.entity, field):
                            item[field] = getattr(hit.entity, field)
                hits.append(item)

            if span:
                span.set_attribute("milvus.results", len(hits))

            return hits

    async def search_with_texts(
        self,
        embedding: list[float],
        limit: int = 10,
        text_field: str = "text",
        metadata_fields: Optional[list[str]] = None,
        filter_expr: Optional[str] = None,
    ) -> list[dict]:
        """
        Search and return text content with metadata.

        Args:
            embedding: Query embedding
            limit: Maximum results
            text_field: Name of text field in collection
            metadata_fields: Additional metadata fields to return
            filter_expr: Optional filter expression forwarded to search()

        Returns:
            List of results with text and metadata

        Raises:
            RuntimeError: If no collection is loaded.
        """
        output_fields = [text_field, *(metadata_fields or [])]
        return await self.search(
            embedding,
            limit=limit,
            output_fields=output_fields,
            filter_expr=filter_expr,
        )

    async def insert(
        self,
        embeddings: list[list[float]],
        data: list[dict],
    ) -> list[int]:
        """
        Insert vectors with data into the collection.

        Columns are assembled in the collection's schema order. The "id"
        field is skipped, the "embedding" field receives `embeddings`, and
        every other field is read from `data` by field name (None when a
        row lacks the key). The collection is flushed after the insert.

        Args:
            embeddings: List of embedding vectors
            data: List of dicts with field values

        Returns:
            List of inserted IDs (empty when `embeddings` is empty)

        Raises:
            RuntimeError: If no collection is loaded.
        """
        if self._collection is None:
            raise RuntimeError("Not connected to collection")

        if not embeddings:
            # Nothing to insert; skip the round trip to Milvus.
            return []

        with create_span("milvus.insert") as span:
            if span:
                span.set_attribute("milvus.collection", self._collection.name)
                span.set_attribute("milvus.count", len(embeddings))

            # Column-based inserts are matched to schema fields by position,
            # so the vector column must sit where the schema declares it.
            columns = []
            vector_placed = False
            for field in self._collection.schema.fields:
                if field.name == self._PRIMARY_FIELD:
                    continue
                if field.name == self._VECTOR_FIELD:
                    columns.append(embeddings)
                    vector_placed = True
                else:
                    columns.append([row.get(field.name) for row in data])

            if not vector_placed:
                # Schema has no field named "embedding": keep the previous
                # layout with the vectors as the leading column.
                columns.insert(0, embeddings)

            result = self._collection.insert(columns)
            self._collection.flush()

            return result.primary_keys

    def health(self) -> bool:
        """
        Check if connected to Milvus.

        Returns:
            True when connect() has completed and pymilvus still has a
            connection registered under the "default" alias.
        """
        return self._connected and connections.has_connection(self._ALIAS)
|