From 79dbaa6d2ca5ef737214cd231fe4d95b487f38af Mon Sep 17 00:00:00 2001
From: "Billy D."
Date: Fri, 13 Feb 2026 09:19:20 -0500
Subject: [PATCH] fix: add stop_token_ids and clamp max_tokens

- Add Llama 3 stop token IDs (128001, 128009) to SamplingParams as
  safety net for V1 engine max_tokens bug on ROCm/gfx1151
- Clamp max_tokens to min(requested, max_model_len)
- Support DEFAULT_MAX_TOKENS env var (default 256)
- Prevents runaway generation when V1 engine ignores max_tokens
---
 ray_serve/serve_llm.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/ray_serve/serve_llm.py b/ray_serve/serve_llm.py
index 1992083..9b44b0c 100644
--- a/ray_serve/serve_llm.py
+++ b/ray_serve/serve_llm.py
@@ -83,7 +83,15 @@ class LLMDeployment:
         engine_args = AsyncEngineArgs(**engine_kwargs)
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
         self.SamplingParams = SamplingParams
+        self.default_max_tokens = int(os.environ.get("DEFAULT_MAX_TOKENS", "256"))
+        # Llama 3 stop tokens — safety net in case V1 engine ignores max_tokens
+        self._stop_token_ids = [
+            128001,  # <|end_of_text|>
+            128009,  # <|eot_id|>
+        ]
         print(f"Model {self.model_id} async engine created")
+        print(f"Default max tokens: {self.default_max_tokens}")
+        print(f"Stop token IDs: {self._stop_token_ids}")
 
         # MLflow metrics — import locally to avoid cloudpickle
         # serializing a module reference that fails on the worker
@@ -128,7 +136,10 @@
         """
         messages = request.get("messages", [])
         temperature = request.get("temperature", 0.7)
-        max_tokens = request.get("max_tokens", 256)
+        max_tokens = min(
+            request.get("max_tokens", self.default_max_tokens),
+            self.max_model_len,
+        )
         top_p = request.get("top_p", 1.0)
         stop = request.get("stop")
 
@@ -140,6 +151,7 @@
             max_tokens=max_tokens,
             top_p=top_p,
             stop=stop,
+            stop_token_ids=self._stop_token_ids,
         )
 
         start_time = time.time()