# NOTE: The strixhalo LLM worker uses py_executable pointing to the Docker
# image venv, which doesn't have the updated ray-serve-apps package. All
# InferenceLogger imports are wrapped in try/except and usage is guarded with
# None checks so apps degrade gracefully without MLflow logging.
"""
|
|
Ray Serve deployment for sentence-transformers BGE embeddings.
|
|
Runs on: drizzt (Radeon 680M iGPU, ROCm)
|
|
"""
|
|
|
|
import os
|
|
import time
|
|
from typing import Any
|
|
|
|
from ray import serve
|
|
|
|
try:
|
|
from ray_serve.mlflow_logger import InferenceLogger
|
|
except ImportError:
|
|
InferenceLogger = None
|
|
|
|
|
|
@serve.deployment(name="EmbeddingsDeployment", num_replicas=1)
class EmbeddingsDeployment:
    """OpenAI-compatible sentence-embeddings service backed by sentence-transformers.

    Loads the model named by the MODEL_ID environment variable (default
    "BAAI/bge-large-en-v1.5") onto the best available device at replica
    startup and serves /v1/embeddings-style requests. MLflow metrics are
    strictly best-effort: if the InferenceLogger import failed at module
    load, or MLflow setup/logging raises at runtime, the deployment keeps
    serving embeddings without metrics.
    """

    def __init__(self) -> None:
        # Heavy imports happen at replica startup, not module import time,
        # so the module stays importable without torch installed.
        import torch
        from sentence_transformers import SentenceTransformer

        self.model_id = os.environ.get("MODEL_ID", "BAAI/bge-large-en-v1.5")

        # Detect device: CUDA first (ROCm builds of torch also report as
        # "cuda"), then Intel XPU, falling back to CPU.
        if torch.cuda.is_available():
            self.device = "cuda"
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            self.device = "xpu"
        else:
            self.device = "cpu"

        print(f"Loading embeddings model: {self.model_id}")
        print(f"Using device: {self.device}")

        self.model = SentenceTransformer(self.model_id, device=self.device)
        self.embedding_dim = self.model.get_sentence_embedding_dimension()

        print(f"Model loaded. Embedding dimension: {self.embedding_dim}")

        # MLflow metrics. InferenceLogger is None when the worker venv lacks
        # the updated ray-serve-apps package; additionally guard setup with
        # try/except so an unreachable tracking server cannot kill replica
        # startup — the app degrades gracefully to no metrics logging.
        self._mlflow = None
        if InferenceLogger is not None:
            try:
                self._mlflow = InferenceLogger(
                    experiment_name="ray-serve-embeddings",
                    run_name=f"embeddings-{self.model_id.split('/')[-1]}",
                    tags={"model.name": self.model_id, "model.framework": "sentence-transformers", "device": self.device},
                    flush_every=10,
                )
                self._mlflow.initialize(
                    params={"model_id": self.model_id, "embedding_dim": str(self.embedding_dim), "device": self.device}
                )
            except Exception as exc:
                # Degrade gracefully: serve embeddings without metrics.
                print(f"MLflow logging disabled (init failed): {exc}")
                self._mlflow = None

    async def __call__(self, request: dict[str, Any]) -> dict[str, Any]:
        """
        Handle OpenAI-compatible embedding requests.

        Expected request format:
            {
                "model": "model-name",
                "input": "text to embed" or ["text1", "text2"],
                "encoding_format": "float"
            }

        Returns an OpenAI-style response: {"object": "list", "data": [...],
        "model": ..., "usage": {...}}. Embeddings are L2-normalized.
        NOTE(review): "encoding_format" is accepted but ignored — only
        float lists are returned; base64 is not supported.
        """
        input_data = request.get("input", "")

        _start = time.time()

        # Accept both a single string and a list of strings.
        texts = [input_data] if isinstance(input_data, str) else input_data

        # Generate normalized embeddings for the whole batch at once.
        embeddings = self.model.encode(
            texts,
            normalize_embeddings=True,
            show_progress_bar=False,
        )

        # Build response data; token count is a whitespace-split
        # approximation, not the model tokenizer's count.
        data = []
        total_tokens = 0
        for i, (text, embedding) in enumerate(zip(texts, embeddings, strict=False)):
            data.append(
                {
                    "object": "embedding",
                    "index": i,
                    "embedding": embedding.tolist(),
                }
            )
            total_tokens += len(text.split())

        # Log to MLflow — best-effort: a logging failure must never break
        # an otherwise-successful embedding response.
        if self._mlflow:
            try:
                self._mlflow.log_request(
                    latency_s=time.time() - _start,
                    batch_size=len(texts),
                    total_tokens=total_tokens,
                )
            except Exception as exc:
                print(f"MLflow log_request failed: {exc}")

        # Return OpenAI-compatible response
        return {
            "object": "list",
            "data": data,
            "model": self.model_id,
            "usage": {
                "prompt_tokens": total_tokens,
                "total_tokens": total_tokens,
            },
        }
|
|
|
|
|
|
# Entrypoint: bound deployment graph consumed by `serve run` / the Serve
# config file (no handle arguments are passed to this deployment).
app = EmbeddingsDeployment.bind()
|