refactor: restructure LLMDeployment engine configuration (prefix caching, chunked prefill, n-gram speculative decoding) and patch vLLM ROCm quantization env override.
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 12s
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 12s
This commit is contained in:
@@ -14,6 +14,30 @@ from ray import serve
|
||||
@serve.deployment(name="LLMDeployment", num_replicas=1)
|
||||
class LLMDeployment:
|
||||
def __init__(self):
|
||||
# Workaround: vLLM's rocm.py:verify_quantization unconditionally
|
||||
# sets VLLM_USE_TRITON_AWQ=1 (a bug — the os.environ line is
|
||||
# outside the if-block). Triton AWQ kernels can't compile for
|
||||
# gfx1151, so we must keep it at 0. Monkey-patch before engine
|
||||
# creation to prevent the override.
|
||||
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
|
||||
try:
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
|
||||
_orig_verify = RocmPlatform.verify_quantization.__func__
|
||||
|
||||
@classmethod # type: ignore[misc]
|
||||
def _patched_verify(cls, quant: str) -> None:
|
||||
saved = os.environ.get("VLLM_USE_TRITON_AWQ")
|
||||
_orig_verify(cls, quant)
|
||||
# Restore our value after vLLM clobbers it
|
||||
if saved is not None:
|
||||
os.environ["VLLM_USE_TRITON_AWQ"] = saved
|
||||
|
||||
RocmPlatform.verify_quantization = _patched_verify
|
||||
print("Patched RocmPlatform.verify_quantization to preserve VLLM_USE_TRITON_AWQ=0")
|
||||
except Exception as exc:
|
||||
print(f"Could not patch RocmPlatform: {exc}")
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
@@ -26,13 +50,33 @@ class LLMDeployment:
|
||||
print(f"Max model length: {self.max_model_len}")
|
||||
print(f"GPU memory utilization: {self.gpu_memory_utilization}")
|
||||
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=self.model_id,
|
||||
max_model_len=self.max_model_len,
|
||||
gpu_memory_utilization=self.gpu_memory_utilization,
|
||||
trust_remote_code=True,
|
||||
disable_log_stats=True,
|
||||
)
|
||||
self.enable_prefix_caching = os.environ.get("ENABLE_PREFIX_CACHING", "true").lower() == "true"
|
||||
self.enable_chunked_prefill = os.environ.get("ENABLE_CHUNKED_PREFILL", "true").lower() == "true"
|
||||
self.num_speculative_tokens = int(os.environ.get("NUM_SPECULATIVE_TOKENS", "0"))
|
||||
self.ngram_prompt_lookup_max = int(os.environ.get("NGRAM_PROMPT_LOOKUP_MAX", "0"))
|
||||
|
||||
engine_kwargs: dict[str, Any] = {
|
||||
"model": self.model_id,
|
||||
"max_model_len": self.max_model_len,
|
||||
"gpu_memory_utilization": self.gpu_memory_utilization,
|
||||
"trust_remote_code": True,
|
||||
"enable_prefix_caching": self.enable_prefix_caching,
|
||||
"enable_chunked_prefill": self.enable_chunked_prefill,
|
||||
}
|
||||
|
||||
# n-gram speculative decoding (no draft model needed)
|
||||
if self.num_speculative_tokens > 0 and self.ngram_prompt_lookup_max > 0:
|
||||
engine_kwargs["speculative_config"] = {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": self.num_speculative_tokens,
|
||||
"ngram_prompt_lookup_max": self.ngram_prompt_lookup_max,
|
||||
}
|
||||
print(f"Speculative decoding: ngram n={self.ngram_prompt_lookup_max}, k={self.num_speculative_tokens}")
|
||||
|
||||
print(f"Prefix caching: {self.enable_prefix_caching}")
|
||||
print(f"Chunked prefill: {self.enable_chunked_prefill}")
|
||||
|
||||
engine_args = AsyncEngineArgs(**engine_kwargs)
|
||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
self.SamplingParams = SamplingParams
|
||||
print(f"Model {self.model_id} async engine created")
|
||||
@@ -54,6 +98,10 @@ class LLMDeployment:
|
||||
"model_id": self.model_id,
|
||||
"max_model_len": str(self.max_model_len),
|
||||
"gpu_memory_utilization": str(self.gpu_memory_utilization),
|
||||
"enable_prefix_caching": str(self.enable_prefix_caching),
|
||||
"enable_chunked_prefill": str(self.enable_chunked_prefill),
|
||||
"num_speculative_tokens": str(self.num_speculative_tokens),
|
||||
"ngram_prompt_lookup_max": str(self.ngram_prompt_lookup_max),
|
||||
}
|
||||
)
|
||||
except ImportError:
|
||||
@@ -97,8 +145,10 @@ class LLMDeployment:
|
||||
generated_text = final_result.outputs[0].text
|
||||
latency = time.time() - start_time
|
||||
|
||||
prompt_tokens = len(prompt.split())
|
||||
completion_tokens = len(generated_text.split())
|
||||
# Use actual output token count from vLLM when available
|
||||
vllm_output = final_result.outputs[0]
|
||||
completion_tokens = len(vllm_output.token_ids) if hasattr(vllm_output, "token_ids") else len(generated_text.split())
|
||||
prompt_tokens = len(final_result.prompt_token_ids) if hasattr(final_result, "prompt_token_ids") else len(prompt.split())
|
||||
|
||||
# Log to MLflow
|
||||
if self._mlflow:
|
||||
|
||||
Reference in New Issue
Block a user