""" Ray Serve deployment for vLLM with OpenAI-compatible API. Runs on: khelben (Strix Halo 64GB, ROCm) """ import os import time import uuid from typing import Any from ray import serve @serve.deployment(name="LLMDeployment", num_replicas=1) class LLMDeployment: def __init__(self): # Workaround: vLLM's rocm.py:verify_quantization unconditionally # sets VLLM_USE_TRITON_AWQ=1 (a bug — the os.environ line is # outside the if-block). Triton AWQ kernels can't compile for # gfx1151, so we must keep it at 0. Monkey-patch before engine # creation to prevent the override. os.environ["VLLM_USE_TRITON_AWQ"] = "0" try: from vllm.platforms.rocm import RocmPlatform _orig_verify = RocmPlatform.verify_quantization.__func__ @classmethod # type: ignore[misc] def _patched_verify(cls, quant: str) -> None: saved = os.environ.get("VLLM_USE_TRITON_AWQ") _orig_verify(cls, quant) # Restore our value after vLLM clobbers it if saved is not None: os.environ["VLLM_USE_TRITON_AWQ"] = saved RocmPlatform.verify_quantization = _patched_verify print("Patched RocmPlatform.verify_quantization to preserve VLLM_USE_TRITON_AWQ=0") except Exception as exc: print(f"Could not patch RocmPlatform: {exc}") from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine self.model_id = os.environ.get("MODEL_ID", "meta-llama/Llama-3.1-70B-Instruct") self.max_model_len = int(os.environ.get("MAX_MODEL_LEN", "8192")) self.gpu_memory_utilization = float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.9")) print(f"Loading vLLM model: {self.model_id}") print(f"Max model length: {self.max_model_len}") print(f"GPU memory utilization: {self.gpu_memory_utilization}") self.enable_prefix_caching = os.environ.get("ENABLE_PREFIX_CACHING", "true").lower() == "true" self.enable_chunked_prefill = os.environ.get("ENABLE_CHUNKED_PREFILL", "true").lower() == "true" self.num_speculative_tokens = int(os.environ.get("NUM_SPECULATIVE_TOKENS", "0")) self.ngram_prompt_lookup_max = int(os.environ.get("NGRAM_PROMPT_LOOKUP_MAX", "0")) self.enforce_eager = os.environ.get("ENFORCE_EAGER", "false").lower() == "true" engine_kwargs: dict[str, Any] = { "model": self.model_id, "max_model_len": self.max_model_len, "gpu_memory_utilization": self.gpu_memory_utilization, "trust_remote_code": True, "enable_prefix_caching": self.enable_prefix_caching, "enable_chunked_prefill": self.enable_chunked_prefill, "enforce_eager": self.enforce_eager, } # n-gram speculative decoding (no draft model needed) if self.num_speculative_tokens > 0 and self.ngram_prompt_lookup_max > 0: engine_kwargs["speculative_config"] = { "method": "ngram", "num_speculative_tokens": self.num_speculative_tokens, "ngram_prompt_lookup_max": self.ngram_prompt_lookup_max, } print(f"Speculative decoding: ngram n={self.ngram_prompt_lookup_max}, k={self.num_speculative_tokens}") print(f"Prefix caching: {self.enable_prefix_caching}") print(f"Chunked prefill: {self.enable_chunked_prefill}") print(f"Enforce eager (no torch.compile): {self.enforce_eager}") engine_args = AsyncEngineArgs(**engine_kwargs) self.engine = AsyncLLMEngine.from_engine_args(engine_args) self.SamplingParams = SamplingParams print(f"Model {self.model_id} async engine created") # MLflow metrics — import locally to avoid cloudpickle # serializing a module reference that fails on the worker # (strixhalo uses py_executable which bypasses pip runtime_env) try: from ray_serve.mlflow_logger import InferenceLogger self._mlflow = InferenceLogger( experiment_name="ray-serve-llm", run_name=f"llm-{self.model_id.split('/')[-1]}", tags={"model.name": self.model_id, "model.framework": "vllm", "gpu": "strixhalo"}, flush_every=5, ) self._mlflow.initialize( params={ "model_id": self.model_id, "max_model_len": str(self.max_model_len), "gpu_memory_utilization": str(self.gpu_memory_utilization), "enable_prefix_caching": str(self.enable_prefix_caching), "enable_chunked_prefill": str(self.enable_chunked_prefill), "num_speculative_tokens": str(self.num_speculative_tokens), "ngram_prompt_lookup_max": str(self.ngram_prompt_lookup_max), "enforce_eager": str(self.enforce_eager), } ) except ImportError: self._mlflow = None async def __call__(self, request: dict[str, Any]) -> dict[str, Any]: """ Handle OpenAI-compatible chat completion requests. Expected request format: { "model": "model-name", "messages": [{"role": "user", "content": "Hello"}], "temperature": 0.7, "max_tokens": 256, "top_p": 1.0, "stream": false } """ messages = request.get("messages", []) temperature = request.get("temperature", 0.7) max_tokens = request.get("max_tokens", 256) top_p = request.get("top_p", 1.0) stop = request.get("stop") # Convert messages to prompt prompt = self._format_messages(messages) sampling_params = self.SamplingParams( temperature=temperature, max_tokens=max_tokens, top_p=top_p, stop=stop, ) start_time = time.time() request_id = uuid.uuid4().hex final_result = None async for result in self.engine.generate(prompt, sampling_params, request_id): final_result = result generated_text = final_result.outputs[0].text latency = time.time() - start_time # Use actual output token count from vLLM when available vllm_output = final_result.outputs[0] completion_tokens = len(vllm_output.token_ids) if hasattr(vllm_output, "token_ids") else len(generated_text.split()) prompt_tokens = len(final_result.prompt_token_ids) if hasattr(final_result, "prompt_token_ids") else len(prompt.split()) # Log to MLflow if self._mlflow: self._mlflow.log_request( latency_s=latency, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, tokens_per_second=completion_tokens / latency if latency > 0 else 0, temperature=temperature, max_tokens_requested=max_tokens, ) # Return OpenAI-compatible response return { "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", "object": "chat.completion", "created": int(time.time()), "model": self.model_id, "choices": [ { "index": 0, "message": { "role": "assistant", "content": generated_text, }, "finish_reason": "stop", } ], "usage": { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens, }, } def _format_messages(self, messages: list[dict[str, str]]) -> str: """Format chat messages into a prompt string.""" formatted = "" for msg in messages: role = msg.get("role", "user") content = msg.get("content", "") if role == "system": formatted += f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>" elif role == "user": formatted += f"<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>" elif role == "assistant": formatted += f"<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>" formatted += "<|start_header_id|>assistant<|end_header_id|>\n\n" return formatted app = LLMDeployment.bind()