fix: add stop_token_ids and clamp max_tokens
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 12s
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 12s
- Add Llama 3 stop token IDs (128001, 128009) to SamplingParams as safety net for V1 engine max_tokens bug on ROCm/gfx1151 - Clamp max_tokens to min(requested, max_model_len) - Support DEFAULT_MAX_TOKENS env var (default 256) - Prevents runaway generation when V1 engine ignores max_tokens
This commit is contained in:
@@ -83,7 +83,15 @@ class LLMDeployment:
|
||||
engine_args = AsyncEngineArgs(**engine_kwargs)
|
||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
self.SamplingParams = SamplingParams
|
||||
self.default_max_tokens = int(os.environ.get("DEFAULT_MAX_TOKENS", "256"))
|
||||
# Llama 3 stop tokens — safety net in case V1 engine ignores max_tokens
|
||||
self._stop_token_ids = [
|
||||
128001, # <|end_of_text|>
|
||||
128009, # <|eot_id|>
|
||||
]
|
||||
print(f"Model {self.model_id} async engine created")
|
||||
print(f"Default max tokens: {self.default_max_tokens}")
|
||||
print(f"Stop token IDs: {self._stop_token_ids}")
|
||||
|
||||
# MLflow metrics — import locally to avoid cloudpickle
|
||||
# serializing a module reference that fails on the worker
|
||||
@@ -128,7 +136,10 @@ class LLMDeployment:
|
||||
"""
|
||||
messages = request.get("messages", [])
|
||||
temperature = request.get("temperature", 0.7)
|
||||
max_tokens = request.get("max_tokens", 256)
|
||||
max_tokens = min(
|
||||
request.get("max_tokens", self.default_max_tokens),
|
||||
self.max_model_len,
|
||||
)
|
||||
top_p = request.get("top_p", 1.0)
|
||||
stop = request.get("stop")
|
||||
|
||||
@@ -140,6 +151,7 @@ class LLMDeployment:
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
stop=stop,
|
||||
stop_token_ids=self._stop_token_ids,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
Reference in New Issue
Block a user