diff --git a/ray_serve/serve_llm.py b/ray_serve/serve_llm.py
index 1992083..9b44b0c 100644
--- a/ray_serve/serve_llm.py
+++ b/ray_serve/serve_llm.py
@@ -83,7 +83,15 @@ class LLMDeployment:
         engine_args = AsyncEngineArgs(**engine_kwargs)
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
         self.SamplingParams = SamplingParams
+        self.default_max_tokens = int(os.environ.get("DEFAULT_MAX_TOKENS", "256"))
+        # Llama 3 stop tokens — safety net in case V1 engine ignores max_tokens
+        self._stop_token_ids = [
+            128001,  # <|end_of_text|>
+            128009,  # <|eot_id|>
+        ]
         print(f"Model {self.model_id} async engine created")
+        print(f"Default max tokens: {self.default_max_tokens}")
+        print(f"Stop token IDs: {self._stop_token_ids}")
 
         # MLflow metrics — import locally to avoid cloudpickle
         # serializing a module reference that fails on the worker
@@ -128,7 +136,10 @@ class LLMDeployment:
         """
         messages = request.get("messages", [])
         temperature = request.get("temperature", 0.7)
-        max_tokens = request.get("max_tokens", 256)
+        max_tokens = min(
+            request.get("max_tokens", self.default_max_tokens),
+            self.max_model_len,
+        )
         top_p = request.get("top_p", 1.0)
         stop = request.get("stop")
 
@@ -140,6 +151,7 @@ class LLMDeployment:
             max_tokens=max_tokens,
             top_p=top_p,
             stop=stop,
+            stop_token_ids=self._stop_token_ids,
         )
 
         start_time = time.time()