# mlflow/mlflow_utils/inference_tracker.py
"""
Inference Metrics Tracker for NATS Handlers
Provides async-compatible MLflow logging for real-time inference
metrics from chat-handler and voice-assistant services.
Designed to integrate with the existing OpenTelemetry setup and
complement OTel metrics with MLflow experiment tracking for
longer-term analysis and model comparison.
"""
import asyncio
import logging
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Optional
import mlflow
from mlflow.tracking import MlflowClient
from .client import MLflowConfig, ensure_experiment, get_mlflow_client
logger = logging.getLogger(__name__)
@dataclass
class InferenceMetrics:
    """Metrics collected during an inference request."""

    request_id: str
    user_id: Optional[str] = None
    session_id: Optional[str] = None

    # Timing metrics (in seconds)
    total_latency: float = 0.0
    embedding_latency: float = 0.0
    rag_search_latency: float = 0.0
    rerank_latency: float = 0.0
    llm_latency: float = 0.0
    tts_latency: float = 0.0
    stt_latency: float = 0.0

    # Token/size metrics
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    prompt_length: int = 0
    response_length: int = 0

    # RAG metrics
    rag_enabled: bool = False
    rag_documents_retrieved: int = 0
    rag_documents_used: int = 0
    reranker_enabled: bool = False

    # Quality indicators
    is_streaming: bool = False
    is_premium: bool = False
    has_error: bool = False
    error_message: Optional[str] = None

    # Model information
    model_name: Optional[str] = None
    model_endpoint: Optional[str] = None

    # Timestamps
    timestamp: float = field(default_factory=time.time)

    def as_metrics_dict(self) -> Dict[str, float]:
        """Convert numeric fields to a metrics dictionary."""
        return {
            "total_latency": self.total_latency,
            "embedding_latency": self.embedding_latency,
            "rag_search_latency": self.rag_search_latency,
            "rerank_latency": self.rerank_latency,
            "llm_latency": self.llm_latency,
            "tts_latency": self.tts_latency,
            "stt_latency": self.stt_latency,
            "input_tokens": float(self.input_tokens),
            "output_tokens": float(self.output_tokens),
            "total_tokens": float(self.total_tokens),
            "prompt_length": float(self.prompt_length),
            "response_length": float(self.response_length),
            "rag_documents_retrieved": float(self.rag_documents_retrieved),
            "rag_documents_used": float(self.rag_documents_used),
        }

    def as_params_dict(self) -> Dict[str, str]:
        """Convert configuration fields to a params dictionary."""
        params = {
            "rag_enabled": str(self.rag_enabled),
            "reranker_enabled": str(self.reranker_enabled),
            "is_streaming": str(self.is_streaming),
            "is_premium": str(self.is_premium),
        }
        if self.model_name:
            params["model_name"] = self.model_name
        if self.model_endpoint:
            params["model_endpoint"] = self.model_endpoint
        return params
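# Illustrative example (not part of the module's API): how a handler might
# populate an InferenceMetrics record and what the two views return. The
# request ID and latency values below are made up.
#
#     m = InferenceMetrics(request_id="req-123", llm_latency=0.42, output_tokens=128)
#     m.total_tokens = m.input_tokens + m.output_tokens
#     m.as_metrics_dict()["llm_latency"]   # 0.42
#     m.as_params_dict()["rag_enabled"]    # "False"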
class InferenceMetricsTracker:
    """
    Async-compatible MLflow tracker for inference metrics.

    Uses batching and a background thread pool to avoid blocking
    the async event loop during MLflow calls.

    Example usage in chat-handler:

        class ChatHandler:
            def __init__(self):
                self.mlflow_tracker = InferenceMetricsTracker(
                    service_name="chat-handler",
                    experiment_name="chat-inference"
                )

            async def setup(self):
                await self.mlflow_tracker.start()

            async def process_request(self, msg):
                metrics = InferenceMetrics(request_id=request_id)

                # Track timing
                start = time.time()
                # ... do embedding ...
                metrics.embedding_latency = time.time() - start
                # ... more processing ...

                # Log metrics (non-blocking)
                await self.mlflow_tracker.log_inference(metrics)

            async def shutdown(self):
                await self.mlflow_tracker.stop()
    """
    def __init__(
        self,
        service_name: str,
        experiment_name: Optional[str] = None,
        tracking_uri: Optional[str] = None,
        batch_size: int = 50,
        flush_interval_seconds: float = 60.0,
        enable_batching: bool = True,
        max_workers: int = 2,
    ):
        """
        Initialize the inference metrics tracker.

        Args:
            service_name: Name of the service (e.g., "chat-handler")
            experiment_name: MLflow experiment name (defaults to "<service_name>-inference")
            tracking_uri: Override default tracking URI
            batch_size: Number of metrics to batch before flushing
            flush_interval_seconds: Maximum time between flushes
            enable_batching: If False, log each request immediately
            max_workers: Number of thread pool workers for MLflow calls
        """
        self.service_name = service_name
        self.experiment_name = experiment_name or f"{service_name}-inference"
        self.tracking_uri = tracking_uri
        self.batch_size = batch_size
        self.flush_interval = flush_interval_seconds
        self.enable_batching = enable_batching
        self.config = MLflowConfig()

        self._batch: List[InferenceMetrics] = []
        self._batch_lock = asyncio.Lock()
        self._flush_task: Optional[asyncio.Task] = None
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._running = False
        self._client: Optional[MlflowClient] = None
        self._experiment_id: Optional[str] = None

        # Aggregate metrics for periodic logging
        self._aggregate_metrics: Dict[str, List[float]] = defaultdict(list)
        self._request_count = 0
        self._error_count = 0
    async def start(self) -> None:
        """Start the tracker and initialize MLflow connection."""
        if self._running:
            return
        self._running = True

        # Initialize MLflow in thread pool to avoid blocking
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(self._executor, self._init_mlflow)

        if self.enable_batching:
            self._flush_task = asyncio.create_task(self._periodic_flush())

        logger.info(
            f"InferenceMetricsTracker started for {self.service_name} "
            f"(experiment: {self.experiment_name})"
        )

    def _init_mlflow(self) -> None:
        """Initialize MLflow client and experiment (runs in thread pool)."""
        self._client = get_mlflow_client(tracking_uri=self.tracking_uri, configure_global=True)
        self._experiment_id = ensure_experiment(
            self.experiment_name,
            tags={
                "service": self.service_name,
                "type": "inference-metrics",
            },
        )

    async def stop(self) -> None:
        """Stop the tracker and flush remaining metrics."""
        if not self._running:
            return
        self._running = False

        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

        # Final flush
        await self._flush_batch()
        self._executor.shutdown(wait=True)
        logger.info(f"InferenceMetricsTracker stopped for {self.service_name}")
    async def log_inference(self, metrics: InferenceMetrics) -> None:
        """
        Log inference metrics (non-blocking).

        Args:
            metrics: InferenceMetrics object with request data
        """
        if not self._running:
            logger.warning("Tracker not running, skipping metrics")
            return

        self._request_count += 1
        if metrics.has_error:
            self._error_count += 1

        # Update aggregates
        for key, value in metrics.as_metrics_dict().items():
            if value > 0:
                self._aggregate_metrics[key].append(value)

        if self.enable_batching:
            async with self._batch_lock:
                self._batch.append(metrics)
                if len(self._batch) >= self.batch_size:
                    # Fire-and-forget flush; the task acquires the lock
                    # only after this coroutine has released it.
                    asyncio.create_task(self._flush_batch())
        else:
            # Immediate logging in thread pool
            loop = asyncio.get_event_loop()
            await loop.run_in_executor(self._executor, partial(self._log_single_inference, metrics))

    async def _periodic_flush(self) -> None:
        """Periodically flush batched metrics."""
        while self._running:
            await asyncio.sleep(self.flush_interval)
            await self._flush_batch()

    async def _flush_batch(self) -> None:
        """Flush the current batch of metrics to MLflow."""
        async with self._batch_lock:
            if not self._batch:
                return
            batch = self._batch
            self._batch = []

        # Log in thread pool
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(self._executor, partial(self._log_batch, batch))
    def _log_single_inference(self, metrics: InferenceMetrics) -> None:
        """Log a single inference request to MLflow (runs in thread pool)."""
        try:
            with mlflow.start_run(
                experiment_id=self._experiment_id,
                run_name=f"inference-{metrics.request_id}",
                tags={
                    "service": self.service_name,
                    "request_id": metrics.request_id,
                    "type": "single-inference",
                },
            ):
                mlflow.log_params(metrics.as_params_dict())
                mlflow.log_metrics(metrics.as_metrics_dict())
                if metrics.user_id:
                    mlflow.set_tag("user_id", metrics.user_id)
                if metrics.session_id:
                    mlflow.set_tag("session_id", metrics.session_id)
                if metrics.has_error:
                    mlflow.set_tag("has_error", "true")
                if metrics.error_message:
                    mlflow.set_tag("error_message", metrics.error_message[:250])
        except Exception as e:
            logger.error(f"Failed to log inference metrics: {e}")

    def _log_batch(self, batch: List[InferenceMetrics]) -> None:
        """Log a batch of inference metrics as aggregate statistics."""
        if not batch:
            return
        try:
            # Calculate aggregates
            aggregates = self._calculate_aggregates(batch)
            run_name = f"batch-{self.service_name}-{int(time.time())}"

            with mlflow.start_run(
                experiment_id=self._experiment_id,
                run_name=run_name,
                tags={
                    "service": self.service_name,
                    "type": "batch-inference",
                    "batch_size": str(len(batch)),
                },
            ):
                # Log aggregate metrics
                mlflow.log_metrics(aggregates)

                # Log batch info
                mlflow.log_param("batch_size", len(batch))
                mlflow.log_param("time_window_start", min(m.timestamp for m in batch))
                mlflow.log_param("time_window_end", max(m.timestamp for m in batch))

                # Log configuration breakdown
                rag_enabled_count = sum(1 for m in batch if m.rag_enabled)
                streaming_count = sum(1 for m in batch if m.is_streaming)
                premium_count = sum(1 for m in batch if m.is_premium)
                error_count = sum(1 for m in batch if m.has_error)
                mlflow.log_metrics(
                    {
                        "rag_enabled_pct": rag_enabled_count / len(batch) * 100,
                        "streaming_pct": streaming_count / len(batch) * 100,
                        "premium_pct": premium_count / len(batch) * 100,
                        "error_rate": error_count / len(batch) * 100,
                    }
                )

                # Log model distribution
                model_counts: Dict[str, int] = defaultdict(int)
                for m in batch:
                    if m.model_name:
                        model_counts[m.model_name] += 1
                if model_counts:
                    mlflow.log_dict({"models": dict(model_counts)}, "model_distribution.json")

            logger.info(f"Logged batch of {len(batch)} inference metrics")
        except Exception as e:
            logger.error(f"Failed to log batch metrics: {e}")
    def _calculate_aggregates(self, batch: List[InferenceMetrics]) -> Dict[str, float]:
        """Calculate aggregate statistics from a batch of metrics."""
        import statistics

        aggregates: Dict[str, float] = {}

        # Collect all numeric metrics
        metric_values: Dict[str, List[float]] = defaultdict(list)
        for m in batch:
            for key, value in m.as_metrics_dict().items():
                if value > 0:
                    metric_values[key].append(value)

        # Calculate statistics for each metric
        for key, values in metric_values.items():
            if not values:
                continue
            aggregates[f"{key}_mean"] = statistics.mean(values)
            aggregates[f"{key}_min"] = min(values)
            aggregates[f"{key}_max"] = max(values)
            if len(values) >= 2:
                aggregates[f"{key}_p50"] = statistics.median(values)
                aggregates[f"{key}_stdev"] = statistics.stdev(values)
            if len(values) >= 4:
                # Approximate percentiles by indexing into the sorted values
                sorted_vals = sorted(values)
                p95_idx = int(len(sorted_vals) * 0.95)
                p99_idx = int(len(sorted_vals) * 0.99)
                aggregates[f"{key}_p95"] = sorted_vals[p95_idx]
                aggregates[f"{key}_p99"] = sorted_vals[p99_idx]

        # Add counts
        aggregates["total_requests"] = float(len(batch))
        return aggregates
    def get_stats(self) -> Dict[str, Any]:
        """Get current tracker statistics."""
        return {
            "service_name": self.service_name,
            "experiment_name": self.experiment_name,
            "running": self._running,
            "total_requests": self._request_count,
            "error_count": self._error_count,
            "pending_batch_size": len(self._batch),
            "aggregate_metrics_count": len(self._aggregate_metrics),
        }
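# Minimal end-to-end sketch, kept off the import path behind __main__. It
# assumes an MLflow tracking server reachable through the MLflowConfig /
# get_mlflow_client defaults; the service and request names are placeholders.
if __name__ == "__main__":

    async def _demo() -> None:
        tracker = InferenceMetricsTracker(
            service_name="chat-handler",
            experiment_name="chat-inference",
            enable_batching=False,  # log each request immediately for the demo
        )
        await tracker.start()
        try:
            metrics = InferenceMetrics(
                request_id="demo-001",
                model_name="demo-model",
                llm_latency=0.42,
                total_latency=0.55,
                input_tokens=120,
                output_tokens=80,
                total_tokens=200,
            )
            await tracker.log_inference(metrics)
            print(tracker.get_stats())
        finally:
            await tracker.stop()

    asyncio.run(_demo())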