fix: add stop_token_ids and clamp max_tokens
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 12s
- Add Llama 3 stop token IDs (128001, 128009) to SamplingParams as a safety net for the V1 engine max_tokens bug on ROCm/gfx1151
- Clamp max_tokens to min(requested, max_model_len)
- Support DEFAULT_MAX_TOKENS env var (default 256)
- Prevents runaway generation when the V1 engine ignores max_tokens
This commit is contained in:
@@ -83,7 +83,15 @@ class LLMDeployment:
|
|||||||
engine_args = AsyncEngineArgs(**engine_kwargs)
|
engine_args = AsyncEngineArgs(**engine_kwargs)
|
||||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
self.SamplingParams = SamplingParams
|
self.SamplingParams = SamplingParams
|
||||||
|
self.default_max_tokens = int(os.environ.get("DEFAULT_MAX_TOKENS", "256"))
|
||||||
|
# Llama 3 stop tokens — safety net in case V1 engine ignores max_tokens
|
||||||
|
self._stop_token_ids = [
|
||||||
|
128001, # <|end_of_text|>
|
||||||
|
128009, # <|eot_id|>
|
||||||
|
]
|
||||||
print(f"Model {self.model_id} async engine created")
|
print(f"Model {self.model_id} async engine created")
|
||||||
|
print(f"Default max tokens: {self.default_max_tokens}")
|
||||||
|
print(f"Stop token IDs: {self._stop_token_ids}")
|
||||||
|
|
||||||
# MLflow metrics — import locally to avoid cloudpickle
|
# MLflow metrics — import locally to avoid cloudpickle
|
||||||
# serializing a module reference that fails on the worker
|
# serializing a module reference that fails on the worker
|
||||||
@@ -128,7 +136,10 @@ class LLMDeployment:
|
|||||||
"""
|
"""
|
||||||
messages = request.get("messages", [])
|
messages = request.get("messages", [])
|
||||||
temperature = request.get("temperature", 0.7)
|
temperature = request.get("temperature", 0.7)
|
||||||
max_tokens = request.get("max_tokens", 256)
|
max_tokens = min(
|
||||||
|
request.get("max_tokens", self.default_max_tokens),
|
||||||
|
self.max_model_len,
|
||||||
|
)
|
||||||
top_p = request.get("top_p", 1.0)
|
top_p = request.get("top_p", 1.0)
|
||||||
stop = request.get("stop")
|
stop = request.get("stop")
|
||||||
|
|
||||||
@@ -140,6 +151,7 @@ class LLMDeployment:
|
|||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
|
stop_token_ids=self._stop_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|||||||
Reference in New Issue
Block a user