"""
Ray Serve deployment for vLLM with OpenAI-compatible API.

Runs on: khelben (Strix Halo 64GB, ROCm)
"""

import os
import time
import uuid
from typing import Any

from ray import serve

try:
    from ray_serve.mlflow_logger import InferenceLogger
except ImportError:
    # MLflow logging is optional: without the logger, requests are served
    # but no metrics are recorded.
    InferenceLogger = None


@serve.deployment(name="LLMDeployment", num_replicas=1)
class LLMDeployment:
    """Serve one vLLM model behind an OpenAI-style chat-completions interface.

    Configuration is read from environment variables when the replica starts:

    * ``MODEL_ID`` (default ``meta-llama/Llama-3.1-70B-Instruct``)
    * ``MAX_MODEL_LEN`` (default ``8192``)
    * ``GPU_MEMORY_UTILIZATION`` (default ``0.9``)
    """

    def __init__(self):
        # vLLM is imported here rather than at module level, so it is only
        # loaded inside the replica that actually constructs the engine.
        from vllm import SamplingParams
        from vllm.engine.arg_utils import AsyncEngineArgs
        from vllm.engine.async_llm_engine import AsyncLLMEngine

        self.model_id = os.environ.get("MODEL_ID", "meta-llama/Llama-3.1-70B-Instruct")
        self.max_model_len = int(os.environ.get("MAX_MODEL_LEN", "8192"))
        self.gpu_memory_utilization = float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.9"))

        print(f"Loading vLLM model: {self.model_id}")
        print(f"Max model length: {self.max_model_len}")
        print(f"GPU memory utilization: {self.gpu_memory_utilization}")

        engine_args = AsyncEngineArgs(
            model=self.model_id,
            max_model_len=self.max_model_len,
            gpu_memory_utilization=self.gpu_memory_utilization,
            trust_remote_code=True,
            disable_log_stats=True,
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        # Kept on the instance so __call__ can build per-request params
        # without re-importing vLLM.
        self.SamplingParams = SamplingParams
        print(f"Model {self.model_id} async engine created")

        # MLflow metrics
        if InferenceLogger is not None:
            self._mlflow = InferenceLogger(
                experiment_name="ray-serve-llm",
                run_name=f"llm-{self.model_id.split('/')[-1]}",
                tags={"model.name": self.model_id, "model.framework": "vllm", "gpu": "strixhalo"},
                flush_every=5,
            )
            self._mlflow.initialize(
                params={
                    "model_id": self.model_id,
                    "max_model_len": str(self.max_model_len),
                    "gpu_memory_utilization": str(self.gpu_memory_utilization),
                }
            )
        else:
            self._mlflow = None

    @staticmethod
    def _get_param(request: dict[str, Any], key: str, default: Any) -> Any:
        """Return ``request[key]``, or ``default`` when it is missing or ``None``.

        Clients frequently send explicit JSON ``null`` for unset sampling
        fields; ``dict.get(key, default)`` would pass that ``None`` through.
        Falsy-but-valid values such as ``temperature=0`` are preserved.
        """
        value = request.get(key)
        return default if value is None else value

    async def __call__(self, request: dict[str, Any]) -> dict[str, Any]:
        """
        Handle OpenAI-compatible chat completion requests.

        Expected request format:
        {
            "model": "model-name",
            "messages": [{"role": "user", "content": "Hello"}],
            "temperature": 0.7,
            "max_tokens": 256,
            "top_p": 1.0,
            "stream": false
        }

        ``model`` and ``stream`` are not read: the reply always names this
        deployment's model and is returned as a single, non-streamed
        response. Missing or ``null`` sampling fields fall back to the
        defaults shown above. An optional ``stop`` value is forwarded to vLLM.

        Returns:
            An OpenAI ``chat.completion``-shaped dict. ``usage`` reports token
            counts taken from the engine.

        Raises:
            RuntimeError: if the engine finishes without producing any output.
        """
        messages = request.get("messages") or []
        temperature = self._get_param(request, "temperature", 0.7)
        max_tokens = self._get_param(request, "max_tokens", 256)
        top_p = self._get_param(request, "top_p", 1.0)
        stop = request.get("stop")

        # Convert messages to prompt
        prompt = self._format_messages(messages)

        sampling_params = self.SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stop=stop,
        )

        request_id = uuid.uuid4().hex
        # perf_counter is monotonic, so the measured latency is unaffected by
        # wall-clock adjustments during generation.
        start = time.perf_counter()
        final_result = None
        # The response is not streamed, so only the last update from the
        # engine is kept.
        async for result in self.engine.generate(prompt, sampling_params, request_id):
            final_result = result
        latency = time.perf_counter() - start

        if final_result is None or not final_result.outputs:
            raise RuntimeError(f"vLLM produced no output for request {request_id}")

        output = final_result.outputs[0]
        generated_text = output.text
        # Use the engine's real token ids rather than whitespace word counts,
        # so `usage` and tokens/sec reflect what the model actually processed.
        prompt_tokens = len(final_result.prompt_token_ids or [])
        completion_tokens = len(output.token_ids)
        total_tokens = prompt_tokens + completion_tokens

        # Log to MLflow
        if self._mlflow is not None:
            self._mlflow.log_request(
                latency_s=latency,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                tokens_per_second=completion_tokens / latency if latency > 0 else 0,
                temperature=temperature,
                max_tokens_requested=max_tokens,
            )

        # Return OpenAI-compatible response
        return {
            # Reuse the engine request id so a response can be matched to
            # engine-side logs.
            "id": f"chatcmpl-{request_id}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": self.model_id,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": generated_text,
                    },
                    # vLLM reports e.g. "stop" or "length" (max_tokens hit);
                    # fall back to "stop" if the engine left it unset.
                    "finish_reason": output.finish_reason or "stop",
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens,
            },
        }

    def _format_messages(self, messages: list[dict[str, str]]) -> str:
        """Format chat messages into a Llama 3 style prompt string.

        Each ``system``/``user``/``assistant`` message becomes a header block
        terminated by ``<|eot_id|>``. Messages with any other role are
        skipped, and a missing or ``null`` content is rendered as empty text.
        An open assistant header is appended so the model writes the reply.

        ``<|begin_of_text|>`` is intentionally not written here. The previous
        version emitted it only inside system messages, so the prompt's start
        depended on whether (and where) a system message appeared. The
        tokenizer is expected to add the BOS token itself when vLLM encodes a
        plain-string prompt.
        NOTE(review): confirm against the deployed model's tokenizer config.
        """
        parts = []
        for msg in messages:
            role = msg.get("role", "user")
            if role not in ("system", "user", "assistant"):
                continue
            content = msg.get("content") or ""
            parts.append(f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>")
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
        return "".join(parts)


app = LLMDeployment.bind()