diff --git a/ray_serve/serve_llm.py b/ray_serve/serve_llm.py index 8034217..c067860 100644 --- a/ray_serve/serve_llm.py +++ b/ray_serve/serve_llm.py @@ -10,11 +10,6 @@ from typing import Any from ray import serve -try: - from ray_serve.mlflow_logger import InferenceLogger -except ImportError: - InferenceLogger = None - @serve.deployment(name="LLMDeployment", num_replicas=1) class LLMDeployment: @@ -42,8 +37,12 @@ class LLMDeployment: self.SamplingParams = SamplingParams print(f"Model {self.model_id} async engine created") - # MLflow metrics - if InferenceLogger is not None: + # MLflow metrics — import locally to avoid cloudpickle + # serializing a module reference that fails on the worker + # (strixhalo uses py_executable which bypasses pip runtime_env) + try: + from ray_serve.mlflow_logger import InferenceLogger + self._mlflow = InferenceLogger( experiment_name="ray-serve-llm", run_name=f"llm-{self.model_id.split('/')[-1]}", @@ -57,7 +56,7 @@ class LLMDeployment: "gpu_memory_utilization": str(self.gpu_memory_utilization), } ) - else: + except ImportError: self._mlflow = None async def __call__(self, request: dict[str, Any]) -> dict[str, Any]: