fix: auto-fix ruff linting errors and remove unsupported upload-artifact
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
"""
|
||||
Milvus vector database client.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional, Any
|
||||
|
||||
from pymilvus import connections, Collection, utility
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pymilvus import Collection, connections, utility
|
||||
|
||||
from handler_base.config import Settings
|
||||
from handler_base.telemetry import create_span
|
||||
@@ -15,42 +16,42 @@ logger = logging.getLogger(__name__)
|
||||
class MilvusClient:
|
||||
"""
|
||||
Client for Milvus vector database.
|
||||
|
||||
|
||||
Usage:
|
||||
client = MilvusClient()
|
||||
await client.connect()
|
||||
results = await client.search(embedding, limit=10)
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, settings: Optional[Settings] = None):
|
||||
self.settings = settings or Settings()
|
||||
self._connected = False
|
||||
self._collection: Optional[Collection] = None
|
||||
|
||||
|
||||
async def connect(self, collection_name: Optional[str] = None) -> None:
|
||||
"""
|
||||
Connect to Milvus and load collection.
|
||||
|
||||
|
||||
Args:
|
||||
collection_name: Collection to use (defaults to settings)
|
||||
"""
|
||||
collection_name = collection_name or self.settings.milvus_collection
|
||||
|
||||
|
||||
connections.connect(
|
||||
alias="default",
|
||||
host=self.settings.milvus_host,
|
||||
port=self.settings.milvus_port,
|
||||
)
|
||||
|
||||
|
||||
if utility.has_collection(collection_name):
|
||||
self._collection = Collection(collection_name)
|
||||
self._collection.load()
|
||||
logger.info(f"Connected to Milvus collection: {collection_name}")
|
||||
else:
|
||||
logger.warning(f"Collection {collection_name} does not exist")
|
||||
|
||||
|
||||
self._connected = True
|
||||
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close Milvus connection."""
|
||||
if self._collection:
|
||||
@@ -58,7 +59,7 @@ class MilvusClient:
|
||||
connections.disconnect("default")
|
||||
self._connected = False
|
||||
logger.info("Disconnected from Milvus")
|
||||
|
||||
|
||||
async def search(
|
||||
self,
|
||||
embedding: list[float],
|
||||
@@ -68,26 +69,26 @@ class MilvusClient:
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Search for similar vectors.
|
||||
|
||||
|
||||
Args:
|
||||
embedding: Query embedding vector
|
||||
limit: Maximum number of results
|
||||
output_fields: Fields to return (default: all)
|
||||
filter_expr: Optional filter expression
|
||||
|
||||
|
||||
Returns:
|
||||
List of results with 'id', 'distance', and requested fields
|
||||
"""
|
||||
if not self._collection:
|
||||
raise RuntimeError("Not connected to collection")
|
||||
|
||||
|
||||
with create_span("milvus.search") as span:
|
||||
if span:
|
||||
span.set_attribute("milvus.collection", self._collection.name)
|
||||
span.set_attribute("milvus.limit", limit)
|
||||
|
||||
|
||||
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
|
||||
|
||||
|
||||
results = self._collection.search(
|
||||
data=[embedding],
|
||||
anns_field="embedding",
|
||||
@@ -96,7 +97,7 @@ class MilvusClient:
|
||||
output_fields=output_fields,
|
||||
expr=filter_expr,
|
||||
)
|
||||
|
||||
|
||||
# Convert to list of dicts
|
||||
hits = []
|
||||
for hit in results[0]:
|
||||
@@ -111,12 +112,12 @@ class MilvusClient:
|
||||
if hasattr(hit.entity, field):
|
||||
item[field] = getattr(hit.entity, field)
|
||||
hits.append(item)
|
||||
|
||||
|
||||
if span:
|
||||
span.set_attribute("milvus.results", len(hits))
|
||||
|
||||
|
||||
return hits
|
||||
|
||||
|
||||
async def search_with_texts(
|
||||
self,
|
||||
embedding: list[float],
|
||||
@@ -126,22 +127,22 @@ class MilvusClient:
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Search and return text content with metadata.
|
||||
|
||||
|
||||
Args:
|
||||
embedding: Query embedding
|
||||
limit: Maximum results
|
||||
text_field: Name of text field in collection
|
||||
metadata_fields: Additional metadata fields to return
|
||||
|
||||
|
||||
Returns:
|
||||
List of results with text and metadata
|
||||
"""
|
||||
output_fields = [text_field]
|
||||
if metadata_fields:
|
||||
output_fields.extend(metadata_fields)
|
||||
|
||||
|
||||
return await self.search(embedding, limit, output_fields)
|
||||
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
embeddings: list[list[float]],
|
||||
@@ -149,34 +150,34 @@ class MilvusClient:
|
||||
) -> list[int]:
|
||||
"""
|
||||
Insert vectors with data into the collection.
|
||||
|
||||
|
||||
Args:
|
||||
embeddings: List of embedding vectors
|
||||
data: List of dicts with field values
|
||||
|
||||
|
||||
Returns:
|
||||
List of inserted IDs
|
||||
"""
|
||||
if not self._collection:
|
||||
raise RuntimeError("Not connected to collection")
|
||||
|
||||
|
||||
with create_span("milvus.insert") as span:
|
||||
if span:
|
||||
span.set_attribute("milvus.collection", self._collection.name)
|
||||
span.set_attribute("milvus.count", len(embeddings))
|
||||
|
||||
|
||||
# Build insert data
|
||||
insert_data = [embeddings]
|
||||
for field in self._collection.schema.fields:
|
||||
if field.name not in ("id", "embedding"):
|
||||
field_values = [d.get(field.name) for d in data]
|
||||
insert_data.append(field_values)
|
||||
|
||||
|
||||
result = self._collection.insert(insert_data)
|
||||
self._collection.flush()
|
||||
|
||||
|
||||
return result.primary_keys
|
||||
|
||||
|
||||
def health(self) -> bool:
|
||||
"""Check if connected to Milvus."""
|
||||
return self._connected and utility.get_connection_addr("default") is not None
|
||||
|
||||
Reference in New Issue
Block a user