""" Milvus vector database client. """ import logging from typing import Optional from pymilvus import Collection, connections, utility from handler_base.config import Settings from handler_base.telemetry import create_span logger = logging.getLogger(__name__) class MilvusClient: """ Client for Milvus vector database. Usage: client = MilvusClient() await client.connect() results = await client.search(embedding, limit=10) """ def __init__(self, settings: Optional[Settings] = None): self.settings = settings or Settings() self._connected = False self._collection: Optional[Collection] = None async def connect(self, collection_name: Optional[str] = None) -> None: """ Connect to Milvus and load collection. Args: collection_name: Collection to use (defaults to settings) """ collection_name = collection_name or self.settings.milvus_collection connections.connect( alias="default", host=self.settings.milvus_host, port=self.settings.milvus_port, ) if utility.has_collection(collection_name): self._collection = Collection(collection_name) self._collection.load() logger.info(f"Connected to Milvus collection: {collection_name}") else: logger.warning(f"Collection {collection_name} does not exist") self._connected = True async def close(self) -> None: """Close Milvus connection.""" if self._collection: self._collection.release() connections.disconnect("default") self._connected = False logger.info("Disconnected from Milvus") async def search( self, embedding: list[float], limit: int = 10, output_fields: Optional[list[str]] = None, filter_expr: Optional[str] = None, ) -> list[dict]: """ Search for similar vectors. Args: embedding: Query embedding vector limit: Maximum number of results output_fields: Fields to return (default: all) filter_expr: Optional filter expression Returns: List of results with 'id', 'distance', and requested fields """ if not self._collection: raise RuntimeError("Not connected to collection") with create_span("milvus.search") as span: if span: span.set_attribute("milvus.collection", self._collection.name) span.set_attribute("milvus.limit", limit) search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} results = self._collection.search( data=[embedding], anns_field="embedding", param=search_params, limit=limit, output_fields=output_fields, expr=filter_expr, ) # Convert to list of dicts hits = [] for hit in results[0]: item = { "id": hit.id, "distance": hit.distance, "score": 1 - hit.distance, # Convert distance to similarity } # Add output fields if output_fields: for field in output_fields: if hasattr(hit.entity, field): item[field] = getattr(hit.entity, field) hits.append(item) if span: span.set_attribute("milvus.results", len(hits)) return hits async def search_with_texts( self, embedding: list[float], limit: int = 10, text_field: str = "text", metadata_fields: Optional[list[str]] = None, ) -> list[dict]: """ Search and return text content with metadata. Args: embedding: Query embedding limit: Maximum results text_field: Name of text field in collection metadata_fields: Additional metadata fields to return Returns: List of results with text and metadata """ output_fields = [text_field] if metadata_fields: output_fields.extend(metadata_fields) return await self.search(embedding, limit, output_fields) async def insert( self, embeddings: list[list[float]], data: list[dict], ) -> list[int]: """ Insert vectors with data into the collection. Args: embeddings: List of embedding vectors data: List of dicts with field values Returns: List of inserted IDs """ if not self._collection: raise RuntimeError("Not connected to collection") with create_span("milvus.insert") as span: if span: span.set_attribute("milvus.collection", self._collection.name) span.set_attribute("milvus.count", len(embeddings)) # Build insert data insert_data = [embeddings] for field in self._collection.schema.fields: if field.name not in ("id", "embedding"): field_values = [d.get(field.name) for d in data] insert_data.append(field_values) result = self._collection.insert(insert_data) self._collection.flush() return result.primary_keys def health(self) -> bool: """Check if connected to Milvus.""" return self._connected and utility.get_connection_addr("default") is not None