"""
Ray Serve deployment for sentence-transformers BGE embeddings.
Runs on: drizzt (Radeon 680M iGPU, ROCm)
"""

import os
import time
from typing import Any

from ray import serve

from ray_serve.mlflow_logger import InferenceLogger


@serve.deployment(name="EmbeddingsDeployment", num_replicas=1)
class EmbeddingsDeployment:
    """Serve OpenAI-compatible embedding responses from a SentenceTransformer model.

    The model is chosen by the ``MODEL_ID`` environment variable (default
    ``BAAI/bge-large-en-v1.5``) and loaded once per replica. Every request is
    reported to MLflow through ``InferenceLogger``.
    """

    def __init__(self):
        # Heavy ML imports are kept local so they are only resolved when a
        # replica is actually constructed, not when the module is imported.
        import torch
        from sentence_transformers import SentenceTransformer

        self.model_id = os.environ.get("MODEL_ID", "BAAI/bge-large-en-v1.5")

        # Detect device: prefer CUDA, then XPU, and fall back to CPU.
        # NOTE(review): on the ROCm host named in the module docstring the GPU
        # is presumably reported through torch.cuda -- confirm on the target.
        if torch.cuda.is_available():
            self.device = "cuda"
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            self.device = "xpu"
        else:
            self.device = "cpu"

        print(f"Loading embeddings model: {self.model_id}")
        print(f"Using device: {self.device}")

        self.model = SentenceTransformer(self.model_id, device=self.device)
        self.embedding_dim = self.model.get_sentence_embedding_dimension()

        print(f"Model loaded. Embedding dimension: {self.embedding_dim}")

        # MLflow metrics
        self._mlflow = InferenceLogger(
            experiment_name="ray-serve-embeddings",
            run_name=f"embeddings-{self.model_id.split('/')[-1]}",
            tags={
                "model.name": self.model_id,
                "model.framework": "sentence-transformers",
                "device": self.device,
            },
            flush_every=10,
        )
        self._mlflow.initialize(
            params={
                "model_id": self.model_id,
                "embedding_dim": str(self.embedding_dim),
                "device": self.device,
            }
        )

    async def __call__(self, request: dict[str, Any]) -> dict[str, Any]:
        """
        Handle OpenAI-compatible embedding requests.

        Expected request format:
        {
            "model": "model-name",
            "input": "text to embed" or ["text1", "text2"],
            "encoding_format": "float"
        }

        ``request`` may be an already-parsed dict (e.g. when called through a
        deployment handle) or an HTTP request object exposing an awaitable
        ``json()`` method, which is what Ray Serve passes to an ingress
        deployment. In the latter case the JSON body is parsed first.

        Returns a dict with ``object``, ``data`` (one entry per input text
        holding its index and embedding as a list of floats), ``model`` and
        ``usage``.
        """
        # An HTTP request object is itself a mapping (over the ASGI scope), so
        # calling ``.get("input")`` on it would quietly yield the default and
        # embed an empty string. Parse the body instead; plain dicts have no
        # ``json`` attribute and pass through untouched.
        load_json = getattr(request, "json", None)
        if callable(load_json):
            request = await load_json()

        input_data = request.get("input", "")

        # perf_counter is monotonic, so the reported latency cannot be skewed
        # by wall-clock adjustments the way time.time() can.
        _start = time.perf_counter()

        # Handle both single string and list of strings
        texts = [input_data] if isinstance(input_data, str) else input_data

        # Generate embeddings.
        # NOTE(review): encode() runs synchronously inside this coroutine, so
        # the replica's event loop is occupied for the duration of the call.
        embeddings = self.model.encode(
            texts,
            normalize_embeddings=True,
            show_progress_bar=False,
        )

        # Build response data
        data = [
            {
                "object": "embedding",
                "index": index,
                "embedding": vector.tolist(),
            }
            for index, vector in enumerate(embeddings)
        ]

        # Whitespace-separated word count: an approximation of usage, not the
        # model tokenizer's real token count.
        total_tokens = sum(len(text.split()) for text in texts)

        # Log to MLflow
        self._mlflow.log_request(
            latency_s=time.perf_counter() - _start,
            batch_size=len(texts),
            total_tokens=total_tokens,
        )

        # Return OpenAI-compatible response
        return {
            "object": "list",
            "data": data,
            "model": self.model_id,
            "usage": {
                "prompt_tokens": total_tokens,
                "total_tokens": total_tokens,
            },
        }


app = EmbeddingsDeployment.bind()