""" Inference Metrics Tracker for NATS Handlers Provides async-compatible MLflow logging for real-time inference metrics from chat-handler and voice-assistant services. Designed to integrate with the existing OpenTelemetry setup and complement OTel metrics with MLflow experiment tracking for longer-term analysis and model comparison. """ import asyncio import logging import time from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from functools import partial from typing import Any, Dict, List, Optional import mlflow from mlflow.tracking import MlflowClient from .client import MLflowConfig, ensure_experiment, get_mlflow_client logger = logging.getLogger(__name__) @dataclass class InferenceMetrics: """Metrics collected during an inference request.""" request_id: str user_id: Optional[str] = None session_id: Optional[str] = None # Timing metrics (in seconds) total_latency: float = 0.0 embedding_latency: float = 0.0 rag_search_latency: float = 0.0 rerank_latency: float = 0.0 llm_latency: float = 0.0 tts_latency: float = 0.0 stt_latency: float = 0.0 # Token/size metrics input_tokens: int = 0 output_tokens: int = 0 total_tokens: int = 0 prompt_length: int = 0 response_length: int = 0 # RAG metrics rag_enabled: bool = False rag_documents_retrieved: int = 0 rag_documents_used: int = 0 reranker_enabled: bool = False # Quality indicators is_streaming: bool = False is_premium: bool = False has_error: bool = False error_message: Optional[str] = None # Model information model_name: Optional[str] = None model_endpoint: Optional[str] = None # Timestamps timestamp: float = field(default_factory=time.time) def as_metrics_dict(self) -> Dict[str, float]: """Convert numeric fields to a metrics dictionary.""" return { "total_latency": self.total_latency, "embedding_latency": self.embedding_latency, "rag_search_latency": self.rag_search_latency, "rerank_latency": self.rerank_latency, "llm_latency": self.llm_latency, "tts_latency": self.tts_latency, "stt_latency": self.stt_latency, "input_tokens": float(self.input_tokens), "output_tokens": float(self.output_tokens), "total_tokens": float(self.total_tokens), "prompt_length": float(self.prompt_length), "response_length": float(self.response_length), "rag_documents_retrieved": float(self.rag_documents_retrieved), "rag_documents_used": float(self.rag_documents_used), } def as_params_dict(self) -> Dict[str, str]: """Convert configuration fields to a params dictionary.""" params = { "rag_enabled": str(self.rag_enabled), "reranker_enabled": str(self.reranker_enabled), "is_streaming": str(self.is_streaming), "is_premium": str(self.is_premium), } if self.model_name: params["model_name"] = self.model_name if self.model_endpoint: params["model_endpoint"] = self.model_endpoint return params class InferenceMetricsTracker: """ Async-compatible MLflow tracker for inference metrics. Uses batching and a background thread pool to avoid blocking the async event loop during MLflow calls. Example usage in chat-handler: class ChatHandler: def __init__(self): self.mlflow_tracker = InferenceMetricsTracker( service_name="chat-handler", experiment_name="chat-inference" ) async def setup(self): await self.mlflow_tracker.start() async def process_request(self, msg): metrics = InferenceMetrics(request_id=request_id) # Track timing start = time.time() # ... do embedding ... metrics.embedding_latency = time.time() - start # ... more processing ... # Log metrics (non-blocking) await self.mlflow_tracker.log_inference(metrics) async def shutdown(self): await self.mlflow_tracker.stop() """ def __init__( self, service_name: str, experiment_name: Optional[str] = None, tracking_uri: Optional[str] = None, batch_size: int = 50, flush_interval_seconds: float = 60.0, enable_batching: bool = True, max_workers: int = 2, ): """ Initialize the inference metrics tracker. Args: service_name: Name of the service (e.g., "chat-handler") experiment_name: MLflow experiment name (defaults to service_name) tracking_uri: Override default tracking URI batch_size: Number of metrics to batch before flushing flush_interval_seconds: Maximum time between flushes enable_batching: If False, log each request immediately max_workers: Number of thread pool workers for MLflow calls """ self.service_name = service_name self.experiment_name = experiment_name or f"{service_name}-inference" self.tracking_uri = tracking_uri self.batch_size = batch_size self.flush_interval = flush_interval_seconds self.enable_batching = enable_batching self.config = MLflowConfig() self._batch: List[InferenceMetrics] = [] self._batch_lock = asyncio.Lock() self._flush_task: Optional[asyncio.Task] = None self._executor = ThreadPoolExecutor(max_workers=max_workers) self._running = False self._client: Optional[MlflowClient] = None self._experiment_id: Optional[str] = None # Aggregate metrics for periodic logging self._aggregate_metrics: Dict[str, List[float]] = defaultdict(list) self._request_count = 0 self._error_count = 0 async def start(self) -> None: """Start the tracker and initialize MLflow connection.""" if self._running: return self._running = True # Initialize MLflow in thread pool to avoid blocking loop = asyncio.get_event_loop() await loop.run_in_executor( self._executor, self._init_mlflow ) if self.enable_batching: self._flush_task = asyncio.create_task(self._periodic_flush()) logger.info( f"InferenceMetricsTracker started for {self.service_name} " f"(experiment: {self.experiment_name})" ) def _init_mlflow(self) -> None: """Initialize MLflow client and experiment (runs in thread pool).""" self._client = get_mlflow_client( tracking_uri=self.tracking_uri, configure_global=True ) self._experiment_id = ensure_experiment( self.experiment_name, tags={ "service": self.service_name, "type": "inference-metrics", } ) async def stop(self) -> None: """Stop the tracker and flush remaining metrics.""" if not self._running: return self._running = False if self._flush_task: self._flush_task.cancel() try: await self._flush_task except asyncio.CancelledError: pass # Final flush await self._flush_batch() self._executor.shutdown(wait=True) logger.info(f"InferenceMetricsTracker stopped for {self.service_name}") async def log_inference(self, metrics: InferenceMetrics) -> None: """ Log inference metrics (non-blocking). Args: metrics: InferenceMetrics object with request data """ if not self._running: logger.warning("Tracker not running, skipping metrics") return self._request_count += 1 if metrics.has_error: self._error_count += 1 # Update aggregates for key, value in metrics.as_metrics_dict().items(): if value > 0: self._aggregate_metrics[key].append(value) if self.enable_batching: async with self._batch_lock: self._batch.append(metrics) if len(self._batch) >= self.batch_size: asyncio.create_task(self._flush_batch()) else: # Immediate logging in thread pool loop = asyncio.get_event_loop() await loop.run_in_executor( self._executor, partial(self._log_single_inference, metrics) ) async def _periodic_flush(self) -> None: """Periodically flush batched metrics.""" while self._running: await asyncio.sleep(self.flush_interval) await self._flush_batch() async def _flush_batch(self) -> None: """Flush the current batch of metrics to MLflow.""" async with self._batch_lock: if not self._batch: return batch = self._batch self._batch = [] # Log in thread pool loop = asyncio.get_event_loop() await loop.run_in_executor( self._executor, partial(self._log_batch, batch) ) def _log_single_inference(self, metrics: InferenceMetrics) -> None: """Log a single inference request to MLflow (runs in thread pool).""" try: with mlflow.start_run( experiment_id=self._experiment_id, run_name=f"inference-{metrics.request_id}", tags={ "service": self.service_name, "request_id": metrics.request_id, "type": "single-inference", } ): mlflow.log_params(metrics.as_params_dict()) mlflow.log_metrics(metrics.as_metrics_dict()) if metrics.user_id: mlflow.set_tag("user_id", metrics.user_id) if metrics.session_id: mlflow.set_tag("session_id", metrics.session_id) if metrics.has_error: mlflow.set_tag("has_error", "true") if metrics.error_message: mlflow.set_tag("error_message", metrics.error_message[:250]) except Exception as e: logger.error(f"Failed to log inference metrics: {e}") def _log_batch(self, batch: List[InferenceMetrics]) -> None: """Log a batch of inference metrics as aggregate statistics.""" if not batch: return try: # Calculate aggregates aggregates = self._calculate_aggregates(batch) run_name = f"batch-{self.service_name}-{int(time.time())}" with mlflow.start_run( experiment_id=self._experiment_id, run_name=run_name, tags={ "service": self.service_name, "type": "batch-inference", "batch_size": str(len(batch)), } ): # Log aggregate metrics mlflow.log_metrics(aggregates) # Log batch info mlflow.log_param("batch_size", len(batch)) mlflow.log_param("time_window_start", min(m.timestamp for m in batch)) mlflow.log_param("time_window_end", max(m.timestamp for m in batch)) # Log configuration breakdown rag_enabled_count = sum(1 for m in batch if m.rag_enabled) streaming_count = sum(1 for m in batch if m.is_streaming) premium_count = sum(1 for m in batch if m.is_premium) error_count = sum(1 for m in batch if m.has_error) mlflow.log_metrics({ "rag_enabled_pct": rag_enabled_count / len(batch) * 100, "streaming_pct": streaming_count / len(batch) * 100, "premium_pct": premium_count / len(batch) * 100, "error_rate": error_count / len(batch) * 100, }) # Log model distribution model_counts: Dict[str, int] = defaultdict(int) for m in batch: if m.model_name: model_counts[m.model_name] += 1 if model_counts: mlflow.log_dict( {"models": dict(model_counts)}, "model_distribution.json" ) logger.info(f"Logged batch of {len(batch)} inference metrics") except Exception as e: logger.error(f"Failed to log batch metrics: {e}") def _calculate_aggregates( self, batch: List[InferenceMetrics] ) -> Dict[str, float]: """Calculate aggregate statistics from a batch of metrics.""" import statistics aggregates = {} # Collect all numeric metrics metric_values: Dict[str, List[float]] = defaultdict(list) for m in batch: for key, value in m.as_metrics_dict().items(): if value > 0: metric_values[key].append(value) # Calculate statistics for each metric for key, values in metric_values.items(): if not values: continue aggregates[f"{key}_mean"] = statistics.mean(values) aggregates[f"{key}_min"] = min(values) aggregates[f"{key}_max"] = max(values) if len(values) >= 2: aggregates[f"{key}_p50"] = statistics.median(values) aggregates[f"{key}_stdev"] = statistics.stdev(values) if len(values) >= 4: sorted_vals = sorted(values) p95_idx = int(len(sorted_vals) * 0.95) p99_idx = int(len(sorted_vals) * 0.99) aggregates[f"{key}_p95"] = sorted_vals[p95_idx] aggregates[f"{key}_p99"] = sorted_vals[p99_idx] # Add counts aggregates["total_requests"] = float(len(batch)) return aggregates def get_stats(self) -> Dict[str, Any]: """Get current tracker statistics.""" return { "service_name": self.service_name, "experiment_name": self.experiment_name, "running": self._running, "total_requests": self._request_count, "error_count": self._error_count, "pending_batch_size": len(self._batch), "aggregate_metrics_count": len(self._aggregate_metrics), }