feat: add MLflow inference logging to all Ray Serve apps
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 16s

- Add mlflow_logger.py: lightweight REST-based MLflow logger (no mlflow dep)
- Instrument serve_llm.py with latency, token counts, tokens/sec metrics
- Instrument serve_embeddings.py with latency, batch_size, total_tokens
- Instrument serve_whisper.py with latency, audio_duration, realtime_factor
- Instrument serve_tts.py with latency, audio_duration, text_chars
- Instrument serve_reranker.py with latency, num_pairs, top_k
This commit is contained in:
2026-02-12 06:14:30 -05:00
parent 2edafc33c0
commit 7ec2107e0c
6 changed files with 346 additions and 4 deletions

211
ray_serve/mlflow_logger.py Normal file
View File

@@ -0,0 +1,211 @@
"""
Lightweight MLflow metrics logger using the REST API.
Avoids importing the heavyweight mlflow package — uses only stdlib
urllib so it works inside any Ray Serve actor without extra pip deps.
Each deployment creates **one persistent MLflow run** on startup and
logs per-request metrics with an incrementing step counter. This
gives time-series charts in the MLflow UI. The run is terminated
when the actor shuts down (or left RUNNING if the process crashes).
"""
import atexit
import json
import logging
import os
import threading
import time
import urllib.error
import urllib.request
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class _RunState:
run_id: str
experiment_id: str
step: int = 0
class InferenceLogger:
"""Per-deployment MLflow metrics logger backed by the REST API.
Parameters
----------
experiment_name:
MLflow experiment name (created if missing).
run_name:
Human-readable run name shown in the MLflow UI.
tracking_uri:
MLflow tracking server. Defaults to ``MLFLOW_TRACKING_URI`` env var
or the in-cluster service address.
tags:
Extra tags attached to the run (e.g. model name, GPU, node).
flush_every:
Batch this many metric points before flushing (reduces HTTP calls).
"""
def __init__(
self,
experiment_name: str,
run_name: str,
tracking_uri: str | None = None,
tags: dict[str, str] | None = None,
flush_every: int = 1,
):
self._base = (
tracking_uri
or os.environ.get("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80")
).rstrip("/")
self._experiment_name = experiment_name
self._run_name = run_name
self._tags = tags or {}
self._flush_every = max(1, flush_every)
self._state: _RunState | None = None
self._buffer: list[dict[str, Any]] = []
self._lock = threading.Lock()
self._enabled = True
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def initialize(self, params: dict[str, str] | None = None) -> None:
"""Create experiment + run. Safe to call from ``__init__``."""
try:
exp_id = self._get_or_create_experiment()
run_id = self._create_run(exp_id)
self._state = _RunState(run_id=run_id, experiment_id=exp_id)
if params:
self._log_params(params)
# Register cleanup
atexit.register(self._end_run)
logger.info("MLflow run started: %s (experiment=%s)", run_id, self._experiment_name)
except Exception:
logger.warning("MLflow init failed — metrics will not be logged", exc_info=True)
self._enabled = False
def log_request(self, **metrics: float) -> None:
"""Log one set of metrics for a single inference request.
Metrics are buffered and flushed every ``flush_every`` calls.
"""
if not self._enabled or not self._state:
return
with self._lock:
self._state.step += 1
step = self._state.step
ts = int(time.time() * 1000)
for key, value in metrics.items():
self._buffer.append(
{"key": key, "value": value, "timestamp": ts, "step": step}
)
if len(self._buffer) >= self._flush_every * len(metrics):
self._flush()
def flush(self) -> None:
"""Force-flush any buffered metrics."""
with self._lock:
self._flush()
# ------------------------------------------------------------------
# REST helpers
# ------------------------------------------------------------------
def _post(self, path: str, body: dict) -> dict:
url = f"{self._base}/api/2.0/mlflow/{path}"
data = json.dumps(body).encode()
req = urllib.request.Request(
url, data=data, headers={"Content-Type": "application/json"}, method="POST"
)
with urllib.request.urlopen(req, timeout=10) as resp:
return json.loads(resp.read().decode())
def _get_or_create_experiment(self) -> str:
try:
resp = self._post(
"experiments/get-by-name",
{"experiment_name": self._experiment_name},
)
return resp["experiment"]["experiment_id"]
except urllib.error.HTTPError:
resp = self._post(
"experiments/create",
{"name": self._experiment_name},
)
return resp["experiment_id"]
def _create_run(self, experiment_id: str) -> str:
tags = [
{"key": k, "value": v}
for k, v in {
"mlflow.runName": self._run_name,
"mlflow.source.type": "LOCAL",
"hostname": os.environ.get("HOSTNAME", "unknown"),
"namespace": os.environ.get("POD_NAMESPACE", "unknown"),
**self._tags,
}.items()
]
resp = self._post(
"runs/create",
{
"experiment_id": experiment_id,
"run_name": self._run_name,
"start_time": int(time.time() * 1000),
"tags": tags,
},
)
return resp["run"]["info"]["run_id"]
def _log_params(self, params: dict[str, str]) -> None:
if not self._state:
return
param_list = [{"key": k, "value": str(v)[:500]} for k, v in params.items()]
try:
self._post(
"runs/log-batch",
{"run_id": self._state.run_id, "params": param_list},
)
except Exception:
logger.debug("Failed to log params", exc_info=True)
def _flush(self) -> None:
"""Send buffered metrics in a single `log-batch` call."""
if not self._buffer or not self._state:
return
batch = self._buffer[:]
self._buffer.clear()
try:
self._post(
"runs/log-batch",
{"run_id": self._state.run_id, "metrics": batch},
)
except Exception:
logger.debug("Failed to flush %d metrics", len(batch), exc_info=True)
def _end_run(self) -> None:
"""Mark the MLflow run as FINISHED."""
if not self._state:
return
self._flush()
try:
self._post(
"runs/update",
{
"run_id": self._state.run_id,
"status": "FINISHED",
"end_time": int(time.time() * 1000),
},
)
logger.info("MLflow run %s ended", self._state.run_id)
except Exception:
logger.debug("Failed to end MLflow run", exc_info=True)