feat: Add MLflow integration utilities
- client: Connection management and helpers
- tracker: General experiment tracking
- inference_tracker: Async metrics for NATS handlers
- model_registry: Model registration with KServe metadata
- kfp_components: Kubeflow Pipeline components
- experiment_comparison: Run comparison tools
- cli: Command-line interface
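
A minimal sketch of the package surface, assuming mlflow_utils is importable (only names that actually appear in this commit are shown):

    from mlflow_utils.client import get_mlflow_client, ensure_experiment, MLflowConfig
    from mlflow_utils.inference_tracker import InferenceMetricsTracker, InferenceMetrics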
431
mlflow_utils/inference_tracker.py
Normal file
@@ -0,0 +1,431 @@
"""
Inference Metrics Tracker for NATS Handlers

Provides async-compatible MLflow logging for real-time inference
metrics from chat-handler and voice-assistant services.

Designed to integrate with the existing OpenTelemetry setup and
complement OTel metrics with MLflow experiment tracking for
longer-term analysis and model comparison.
"""

import os
import time
import asyncio
import logging
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import mlflow
from mlflow.tracking import MlflowClient

from .client import get_mlflow_client, ensure_experiment, MLflowConfig

logger = logging.getLogger(__name__)


@dataclass
class InferenceMetrics:
    """Metrics collected during an inference request."""
    request_id: str
    user_id: Optional[str] = None
    session_id: Optional[str] = None

    # Timing metrics (in seconds)
    total_latency: float = 0.0
    embedding_latency: float = 0.0
    rag_search_latency: float = 0.0
    rerank_latency: float = 0.0
    llm_latency: float = 0.0
    tts_latency: float = 0.0
    stt_latency: float = 0.0

    # Token/size metrics
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    prompt_length: int = 0
    response_length: int = 0

    # RAG metrics
    rag_enabled: bool = False
    rag_documents_retrieved: int = 0
    rag_documents_used: int = 0
    reranker_enabled: bool = False

    # Quality indicators
    is_streaming: bool = False
    is_premium: bool = False
    has_error: bool = False
    error_message: Optional[str] = None

    # Model information
    model_name: Optional[str] = None
    model_endpoint: Optional[str] = None

    # Timestamps
    timestamp: float = field(default_factory=time.time)

    def as_metrics_dict(self) -> Dict[str, float]:
        """Convert numeric fields to a metrics dictionary."""
        return {
            "total_latency": self.total_latency,
            "embedding_latency": self.embedding_latency,
            "rag_search_latency": self.rag_search_latency,
            "rerank_latency": self.rerank_latency,
            "llm_latency": self.llm_latency,
            "tts_latency": self.tts_latency,
            "stt_latency": self.stt_latency,
            "input_tokens": float(self.input_tokens),
            "output_tokens": float(self.output_tokens),
            "total_tokens": float(self.total_tokens),
            "prompt_length": float(self.prompt_length),
            "response_length": float(self.response_length),
            "rag_documents_retrieved": float(self.rag_documents_retrieved),
            "rag_documents_used": float(self.rag_documents_used),
        }

    def as_params_dict(self) -> Dict[str, str]:
        """Convert configuration fields to a params dictionary."""
        params = {
            "rag_enabled": str(self.rag_enabled),
            "reranker_enabled": str(self.reranker_enabled),
            "is_streaming": str(self.is_streaming),
            "is_premium": str(self.is_premium),
        }
        if self.model_name:
            params["model_name"] = self.model_name
        if self.model_endpoint:
            params["model_endpoint"] = self.model_endpoint
        return params
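
# Example (illustrative; the values are hypothetical): populating
# InferenceMetrics and converting it into the dictionaries MLflow expects.
#
#   metrics = InferenceMetrics(
#       request_id="req-123",
#       llm_latency=0.42,
#       input_tokens=128,
#       output_tokens=256,
#       rag_enabled=True,
#   )
#   mlflow.log_params(metrics.as_params_dict())    # config flags as strings
#   mlflow.log_metrics(metrics.as_metrics_dict())  # latencies/counts as floats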


class InferenceMetricsTracker:
    """
    Async-compatible MLflow tracker for inference metrics.

    Uses batching and a background thread pool to avoid blocking
    the async event loop during MLflow calls.

    Example usage in chat-handler:

        class ChatHandler:
            def __init__(self):
                self.mlflow_tracker = InferenceMetricsTracker(
                    service_name="chat-handler",
                    experiment_name="chat-inference"
                )

            async def setup(self):
                await self.mlflow_tracker.start()

            async def process_request(self, msg):
                metrics = InferenceMetrics(request_id=request_id)

                # Track timing
                start = time.time()
                # ... do embedding ...
                metrics.embedding_latency = time.time() - start

                # ... more processing ...

                # Log metrics (non-blocking)
                await self.mlflow_tracker.log_inference(metrics)

            async def shutdown(self):
                await self.mlflow_tracker.stop()
    """

    def __init__(
        self,
        service_name: str,
        experiment_name: Optional[str] = None,
        tracking_uri: Optional[str] = None,
        batch_size: int = 50,
        flush_interval_seconds: float = 60.0,
        enable_batching: bool = True,
        max_workers: int = 2,
    ):
        """
        Initialize the inference metrics tracker.

        Args:
            service_name: Name of the service (e.g., "chat-handler")
            experiment_name: MLflow experiment name (defaults to
                "<service_name>-inference")
            tracking_uri: Override the default tracking URI
            batch_size: Number of metrics to batch before flushing
            flush_interval_seconds: Maximum time between flushes
            enable_batching: If False, log each request immediately
            max_workers: Number of thread pool workers for MLflow calls
        """
        self.service_name = service_name
        self.experiment_name = experiment_name or f"{service_name}-inference"
        self.tracking_uri = tracking_uri
        self.batch_size = batch_size
        self.flush_interval = flush_interval_seconds
        self.enable_batching = enable_batching

        self.config = MLflowConfig()
        self._batch: List[InferenceMetrics] = []
        self._batch_lock = asyncio.Lock()
        self._flush_task: Optional[asyncio.Task] = None
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._running = False
        self._client: Optional[MlflowClient] = None
        self._experiment_id: Optional[str] = None

        # Aggregate metrics for periodic logging
        self._aggregate_metrics: Dict[str, List[float]] = defaultdict(list)
        self._request_count = 0
        self._error_count = 0

    async def start(self) -> None:
        """Start the tracker and initialize the MLflow connection."""
        if self._running:
            return

        self._running = True

        # Initialize MLflow in the thread pool to avoid blocking
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            self._executor,
            self._init_mlflow
        )

        if self.enable_batching:
            self._flush_task = asyncio.create_task(self._periodic_flush())

        logger.info(
            f"InferenceMetricsTracker started for {self.service_name} "
            f"(experiment: {self.experiment_name})"
        )

    def _init_mlflow(self) -> None:
        """Initialize the MLflow client and experiment (runs in the thread pool)."""
        self._client = get_mlflow_client(
            tracking_uri=self.tracking_uri,
            configure_global=True
        )
        self._experiment_id = ensure_experiment(
            self.experiment_name,
            tags={
                "service": self.service_name,
                "type": "inference-metrics",
            }
        )

    async def stop(self) -> None:
        """Stop the tracker and flush remaining metrics."""
        if not self._running:
            return

        self._running = False

        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

        # Final flush
        await self._flush_batch()

        self._executor.shutdown(wait=True)
        logger.info(f"InferenceMetricsTracker stopped for {self.service_name}")

    async def log_inference(self, metrics: InferenceMetrics) -> None:
        """
        Log inference metrics (non-blocking).

        Args:
            metrics: InferenceMetrics object with request data
        """
        if not self._running:
            logger.warning("Tracker not running, skipping metrics")
            return

        self._request_count += 1
        if metrics.has_error:
            self._error_count += 1

        # Update aggregates
        for key, value in metrics.as_metrics_dict().items():
            if value > 0:
                self._aggregate_metrics[key].append(value)

        if self.enable_batching:
            async with self._batch_lock:
                self._batch.append(metrics)
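                # Flush in a background task; it re-acquires the batch lock
                # only after this coroutine releases it, so there is no
                # deadlock and the caller never blocks on MLflow I/O.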
                if len(self._batch) >= self.batch_size:
                    asyncio.create_task(self._flush_batch())
        else:
            # Immediate logging in the thread pool
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                self._executor,
                partial(self._log_single_inference, metrics)
            )

    async def _periodic_flush(self) -> None:
        """Periodically flush batched metrics."""
        while self._running:
            await asyncio.sleep(self.flush_interval)
            await self._flush_batch()

    async def _flush_batch(self) -> None:
        """Flush the current batch of metrics to MLflow."""
        async with self._batch_lock:
            if not self._batch:
                return

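            # Swap in a fresh list under the lock so new requests keep
            # appending while this snapshot is handed to the thread pool.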
            batch = self._batch
            self._batch = []

        # Log in the thread pool
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            self._executor,
            partial(self._log_batch, batch)
        )

    def _log_single_inference(self, metrics: InferenceMetrics) -> None:
        """Log a single inference request to MLflow (runs in the thread pool)."""
        try:
            with mlflow.start_run(
                experiment_id=self._experiment_id,
                run_name=f"inference-{metrics.request_id}",
                tags={
                    "service": self.service_name,
                    "request_id": metrics.request_id,
                    "type": "single-inference",
                }
            ):
                mlflow.log_params(metrics.as_params_dict())
                mlflow.log_metrics(metrics.as_metrics_dict())

                if metrics.user_id:
                    mlflow.set_tag("user_id", metrics.user_id)
                if metrics.session_id:
                    mlflow.set_tag("session_id", metrics.session_id)
                if metrics.has_error:
                    mlflow.set_tag("has_error", "true")
                if metrics.error_message:
                    mlflow.set_tag("error_message", metrics.error_message[:250])
        except Exception as e:
            logger.error(f"Failed to log inference metrics: {e}")

    def _log_batch(self, batch: List[InferenceMetrics]) -> None:
        """Log a batch of inference metrics as aggregate statistics."""
        if not batch:
            return

        try:
            # Calculate aggregates
            aggregates = self._calculate_aggregates(batch)

            run_name = f"batch-{self.service_name}-{int(time.time())}"

            with mlflow.start_run(
                experiment_id=self._experiment_id,
                run_name=run_name,
                tags={
                    "service": self.service_name,
                    "type": "batch-inference",
                    "batch_size": str(len(batch)),
                }
            ):
                # Log aggregate metrics
                mlflow.log_metrics(aggregates)

                # Log batch info
                mlflow.log_param("batch_size", len(batch))
                mlflow.log_param("time_window_start", min(m.timestamp for m in batch))
                mlflow.log_param("time_window_end", max(m.timestamp for m in batch))

                # Log configuration breakdown
                rag_enabled_count = sum(1 for m in batch if m.rag_enabled)
                streaming_count = sum(1 for m in batch if m.is_streaming)
                premium_count = sum(1 for m in batch if m.is_premium)
                error_count = sum(1 for m in batch if m.has_error)

                mlflow.log_metrics({
                    "rag_enabled_pct": rag_enabled_count / len(batch) * 100,
                    "streaming_pct": streaming_count / len(batch) * 100,
                    "premium_pct": premium_count / len(batch) * 100,
                    "error_rate": error_count / len(batch) * 100,
                })

                # Log model distribution
                model_counts: Dict[str, int] = defaultdict(int)
                for m in batch:
                    if m.model_name:
                        model_counts[m.model_name] += 1

                if model_counts:
                    mlflow.log_dict(
                        {"models": dict(model_counts)},
                        "model_distribution.json"
                    )

            logger.info(f"Logged batch of {len(batch)} inference metrics")

        except Exception as e:
            logger.error(f"Failed to log batch metrics: {e}")

    def _calculate_aggregates(
        self,
        batch: List[InferenceMetrics]
    ) -> Dict[str, float]:
        """Calculate aggregate statistics from a batch of metrics."""
        import statistics

        aggregates = {}

        # Collect all numeric metrics
        metric_values: Dict[str, List[float]] = defaultdict(list)
        for m in batch:
            for key, value in m.as_metrics_dict().items():
                if value > 0:
                    metric_values[key].append(value)

        # Calculate statistics for each metric
        for key, values in metric_values.items():
            if not values:
                continue

            aggregates[f"{key}_mean"] = statistics.mean(values)
            aggregates[f"{key}_min"] = min(values)
            aggregates[f"{key}_max"] = max(values)

            if len(values) >= 2:
                aggregates[f"{key}_p50"] = statistics.median(values)
                aggregates[f"{key}_stdev"] = statistics.stdev(values)

            if len(values) >= 4:
                sorted_vals = sorted(values)
                p95_idx = int(len(sorted_vals) * 0.95)
                p99_idx = int(len(sorted_vals) * 0.99)
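                # Nearest-rank indexing; for small batches these p95/p99
                # values are coarse approximations.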
                aggregates[f"{key}_p95"] = sorted_vals[p95_idx]
                aggregates[f"{key}_p99"] = sorted_vals[p99_idx]

        # Add counts
        aggregates["total_requests"] = float(len(batch))

        return aggregates

    def get_stats(self) -> Dict[str, Any]:
        """Get current tracker statistics."""
        return {
            "service_name": self.service_name,
            "experiment_name": self.experiment_name,
            "running": self._running,
            "total_requests": self._request_count,
            "error_count": self._error_count,
            "pending_batch_size": len(self._batch),
            "aggregate_metrics_count": len(self._aggregate_metrics),
        }
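

# Illustrative only: get_stats() can back a lightweight health or debug
# endpoint; the "tracker" instance below is hypothetical.
#
#   stats = tracker.get_stats()
#   logger.info("mlflow tracker state: %s", stats)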