minor: refactoring big changes.
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 12s
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 12s
This commit is contained in:
@@ -14,6 +14,30 @@ from ray import serve
|
|||||||
@serve.deployment(name="LLMDeployment", num_replicas=1)
|
@serve.deployment(name="LLMDeployment", num_replicas=1)
|
||||||
class LLMDeployment:
|
class LLMDeployment:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
# Workaround: vLLM's rocm.py:verify_quantization unconditionally
|
||||||
|
# sets VLLM_USE_TRITON_AWQ=1 (a bug — the os.environ line is
|
||||||
|
# outside the if-block). Triton AWQ kernels can't compile for
|
||||||
|
# gfx1151, so we must keep it at 0. Monkey-patch before engine
|
||||||
|
# creation to prevent the override.
|
||||||
|
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
|
||||||
|
try:
|
||||||
|
from vllm.platforms.rocm import RocmPlatform
|
||||||
|
|
||||||
|
_orig_verify = RocmPlatform.verify_quantization.__func__
|
||||||
|
|
||||||
|
@classmethod # type: ignore[misc]
|
||||||
|
def _patched_verify(cls, quant: str) -> None:
|
||||||
|
saved = os.environ.get("VLLM_USE_TRITON_AWQ")
|
||||||
|
_orig_verify(cls, quant)
|
||||||
|
# Restore our value after vLLM clobbers it
|
||||||
|
if saved is not None:
|
||||||
|
os.environ["VLLM_USE_TRITON_AWQ"] = saved
|
||||||
|
|
||||||
|
RocmPlatform.verify_quantization = _patched_verify
|
||||||
|
print("Patched RocmPlatform.verify_quantization to preserve VLLM_USE_TRITON_AWQ=0")
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"Could not patch RocmPlatform: {exc}")
|
||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
@@ -26,13 +50,33 @@ class LLMDeployment:
|
|||||||
print(f"Max model length: {self.max_model_len}")
|
print(f"Max model length: {self.max_model_len}")
|
||||||
print(f"GPU memory utilization: {self.gpu_memory_utilization}")
|
print(f"GPU memory utilization: {self.gpu_memory_utilization}")
|
||||||
|
|
||||||
engine_args = AsyncEngineArgs(
|
self.enable_prefix_caching = os.environ.get("ENABLE_PREFIX_CACHING", "true").lower() == "true"
|
||||||
model=self.model_id,
|
self.enable_chunked_prefill = os.environ.get("ENABLE_CHUNKED_PREFILL", "true").lower() == "true"
|
||||||
max_model_len=self.max_model_len,
|
self.num_speculative_tokens = int(os.environ.get("NUM_SPECULATIVE_TOKENS", "0"))
|
||||||
gpu_memory_utilization=self.gpu_memory_utilization,
|
self.ngram_prompt_lookup_max = int(os.environ.get("NGRAM_PROMPT_LOOKUP_MAX", "0"))
|
||||||
trust_remote_code=True,
|
|
||||||
disable_log_stats=True,
|
engine_kwargs: dict[str, Any] = {
|
||||||
)
|
"model": self.model_id,
|
||||||
|
"max_model_len": self.max_model_len,
|
||||||
|
"gpu_memory_utilization": self.gpu_memory_utilization,
|
||||||
|
"trust_remote_code": True,
|
||||||
|
"enable_prefix_caching": self.enable_prefix_caching,
|
||||||
|
"enable_chunked_prefill": self.enable_chunked_prefill,
|
||||||
|
}
|
||||||
|
|
||||||
|
# n-gram speculative decoding (no draft model needed)
|
||||||
|
if self.num_speculative_tokens > 0 and self.ngram_prompt_lookup_max > 0:
|
||||||
|
engine_kwargs["speculative_config"] = {
|
||||||
|
"method": "ngram",
|
||||||
|
"num_speculative_tokens": self.num_speculative_tokens,
|
||||||
|
"ngram_prompt_lookup_max": self.ngram_prompt_lookup_max,
|
||||||
|
}
|
||||||
|
print(f"Speculative decoding: ngram n={self.ngram_prompt_lookup_max}, k={self.num_speculative_tokens}")
|
||||||
|
|
||||||
|
print(f"Prefix caching: {self.enable_prefix_caching}")
|
||||||
|
print(f"Chunked prefill: {self.enable_chunked_prefill}")
|
||||||
|
|
||||||
|
engine_args = AsyncEngineArgs(**engine_kwargs)
|
||||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
self.SamplingParams = SamplingParams
|
self.SamplingParams = SamplingParams
|
||||||
print(f"Model {self.model_id} async engine created")
|
print(f"Model {self.model_id} async engine created")
|
||||||
@@ -54,6 +98,10 @@ class LLMDeployment:
|
|||||||
"model_id": self.model_id,
|
"model_id": self.model_id,
|
||||||
"max_model_len": str(self.max_model_len),
|
"max_model_len": str(self.max_model_len),
|
||||||
"gpu_memory_utilization": str(self.gpu_memory_utilization),
|
"gpu_memory_utilization": str(self.gpu_memory_utilization),
|
||||||
|
"enable_prefix_caching": str(self.enable_prefix_caching),
|
||||||
|
"enable_chunked_prefill": str(self.enable_chunked_prefill),
|
||||||
|
"num_speculative_tokens": str(self.num_speculative_tokens),
|
||||||
|
"ngram_prompt_lookup_max": str(self.ngram_prompt_lookup_max),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -97,8 +145,10 @@ class LLMDeployment:
|
|||||||
generated_text = final_result.outputs[0].text
|
generated_text = final_result.outputs[0].text
|
||||||
latency = time.time() - start_time
|
latency = time.time() - start_time
|
||||||
|
|
||||||
prompt_tokens = len(prompt.split())
|
# Use actual prompt and completion token counts from vLLM when available; otherwise fall back to whitespace-split word counts
|
||||||
completion_tokens = len(generated_text.split())
|
vllm_output = final_result.outputs[0]
|
||||||
|
completion_tokens = len(vllm_output.token_ids) if hasattr(vllm_output, "token_ids") else len(generated_text.split())
|
||||||
|
prompt_tokens = len(final_result.prompt_token_ids) if hasattr(final_result, "prompt_token_ids") else len(prompt.split())
|
||||||
|
|
||||||
# Log to MLflow
|
# Log to MLflow
|
||||||
if self._mlflow:
|
if self._mlflow:
|
||||||
|
|||||||
Reference in New Issue
Block a user