fix: add stop_token_ids and clamp max_tokens

- Add Llama 3 stop token IDs (128001, 128009) to SamplingParams
  as a safety net for the V1 engine max_tokens bug on ROCm/gfx1151
- Clamp max_tokens to min(requested, max_model_len); see the sketch
  after this list
- Support DEFAULT_MAX_TOKENS env var (default 256)
- Prevent runaway generation when the V1 engine ignores max_tokens
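
A minimal standalone sketch of the intended clamping behavior. The max_model_len value and the clamp_max_tokens helper are hypothetical illustrations; in the deployment the real limit comes from the engine configuration:

import os

default_max_tokens = int(os.environ.get("DEFAULT_MAX_TOKENS", "256"))
max_model_len = 8192  # assumed context window, for illustration only

def clamp_max_tokens(request: dict) -> int:
    # Never let a request exceed the model's context window; fall back
    # to DEFAULT_MAX_TOKENS when the client omits max_tokens.
    return min(request.get("max_tokens", default_max_tokens), max_model_len)

clamp_max_tokens({"max_tokens": 1_000_000})  # -> 8192, clamped
clamp_max_tokens({})                         # -> 256 unless DEFAULT_MAX_TOKENS is set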
2026-02-13 09:19:20 -05:00
parent 96f7650b23
commit 79dbaa6d2c


@@ -83,7 +83,15 @@ class LLMDeployment:
         engine_args = AsyncEngineArgs(**engine_kwargs)
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
         self.SamplingParams = SamplingParams
+        self.default_max_tokens = int(os.environ.get("DEFAULT_MAX_TOKENS", "256"))
+        # Llama 3 stop tokens — safety net in case V1 engine ignores max_tokens
+        self._stop_token_ids = [
+            128001,  # <|end_of_text|>
+            128009,  # <|eot_id|>
+        ]
         print(f"Model {self.model_id} async engine created")
+        print(f"Default max tokens: {self.default_max_tokens}")
+        print(f"Stop token IDs: {self._stop_token_ids}")
         # MLflow metrics — import locally to avoid cloudpickle
         # serializing a module reference that fails on the worker
@@ -128,7 +136,10 @@ class LLMDeployment:
         """
         messages = request.get("messages", [])
         temperature = request.get("temperature", 0.7)
-        max_tokens = request.get("max_tokens", 256)
+        max_tokens = min(
+            request.get("max_tokens", self.default_max_tokens),
+            self.max_model_len,
+        )
         top_p = request.get("top_p", 1.0)
         stop = request.get("stop")
@@ -140,6 +151,7 @@
             max_tokens=max_tokens,
             top_p=top_p,
             stop=stop,
+            stop_token_ids=self._stop_token_ids,
         )
         start_time = time.time()
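
Putting the pieces together, a hedged sketch of the SamplingParams call the handler ends up making. The concrete values are illustrative; only max_tokens, top_p, stop, and stop_token_ids appear in the hunks above, and the temperature kwarg is assumed from the unchanged context:

from vllm import SamplingParams  # the deployment keeps this class on self.SamplingParams

sampling_params = SamplingParams(
    temperature=0.7,
    max_tokens=min(1_000_000, 8192),   # oversized request clamped to the model limit
    top_p=1.0,
    stop=None,
    stop_token_ids=[128001, 128009],   # <|end_of_text|> and <|eot_id|>, the Llama 3 terminators
)

Even if the V1 engine mishandles max_tokens, generation still terminates when either stop token is emitted.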