fix: resolve all ruff lint errors
Some checks failed
CI / Test (push) Successful in 1m46s
CI / Lint (push) Failing after 1m49s
CI / Publish (push) Has been skipped
CI / Notify (push) Successful in 2s

This commit is contained in:
2026-02-13 10:57:57 -05:00
parent 6bcf84549c
commit 1c841729a0
9 changed files with 456 additions and 464 deletions

View File

@@ -9,20 +9,19 @@ complement OTel metrics with MLflow experiment tracking for
longer-term analysis and model comparison.
"""
import os
import time
import asyncio
import logging
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Optional
import mlflow
from mlflow.tracking import MlflowClient
from .client import get_mlflow_client, ensure_experiment, MLflowConfig
from .client import MLflowConfig, ensure_experiment, get_mlflow_client
logger = logging.getLogger(__name__)
@@ -33,7 +32,7 @@ class InferenceMetrics:
request_id: str
user_id: Optional[str] = None
session_id: Optional[str] = None
# Timing metrics (in seconds)
total_latency: float = 0.0
embedding_latency: float = 0.0
@@ -42,33 +41,33 @@ class InferenceMetrics:
llm_latency: float = 0.0
tts_latency: float = 0.0
stt_latency: float = 0.0
# Token/size metrics
input_tokens: int = 0
output_tokens: int = 0
total_tokens: int = 0
prompt_length: int = 0
response_length: int = 0
# RAG metrics
rag_enabled: bool = False
rag_documents_retrieved: int = 0
rag_documents_used: int = 0
reranker_enabled: bool = False
# Quality indicators
is_streaming: bool = False
is_premium: bool = False
has_error: bool = False
error_message: Optional[str] = None
# Model information
model_name: Optional[str] = None
model_endpoint: Optional[str] = None
# Timestamps
timestamp: float = field(default_factory=time.time)
def as_metrics_dict(self) -> Dict[str, float]:
"""Convert numeric fields to a metrics dictionary."""
return {
@@ -87,7 +86,7 @@ class InferenceMetrics:
"rag_documents_retrieved": float(self.rag_documents_retrieved),
"rag_documents_used": float(self.rag_documents_used),
}
def as_params_dict(self) -> Dict[str, str]:
"""Convert configuration fields to a params dictionary."""
params = {
@@ -106,39 +105,39 @@ class InferenceMetrics:
class InferenceMetricsTracker:
"""
Async-compatible MLflow tracker for inference metrics.
Uses batching and a background thread pool to avoid blocking
the async event loop during MLflow calls.
Example usage in chat-handler:
class ChatHandler:
def __init__(self):
self.mlflow_tracker = InferenceMetricsTracker(
service_name="chat-handler",
experiment_name="chat-inference"
)
async def setup(self):
await self.mlflow_tracker.start()
async def process_request(self, msg):
metrics = InferenceMetrics(request_id=request_id)
# Track timing
start = time.time()
# ... do embedding ...
metrics.embedding_latency = time.time() - start
# ... more processing ...
# Log metrics (non-blocking)
await self.mlflow_tracker.log_inference(metrics)
async def shutdown(self):
await self.mlflow_tracker.stop()
"""
def __init__(
self,
service_name: str,
@@ -151,7 +150,7 @@ class InferenceMetricsTracker:
):
"""
Initialize the inference metrics tracker.
Args:
service_name: Name of the service (e.g., "chat-handler")
experiment_name: MLflow experiment name (defaults to service_name)
@@ -167,7 +166,7 @@ class InferenceMetricsTracker:
self.batch_size = batch_size
self.flush_interval = flush_interval_seconds
self.enable_batching = enable_batching
self.config = MLflowConfig()
self._batch: List[InferenceMetrics] = []
self._batch_lock = asyncio.Lock()
@@ -176,34 +175,34 @@ class InferenceMetricsTracker:
self._running = False
self._client: Optional[MlflowClient] = None
self._experiment_id: Optional[str] = None
# Aggregate metrics for periodic logging
self._aggregate_metrics: Dict[str, List[float]] = defaultdict(list)
self._request_count = 0
self._error_count = 0
async def start(self) -> None:
    """Start the tracker and initialize the MLflow connection.

    Idempotent: calling start() on an already-running tracker is a no-op.
    MLflow client/experiment setup runs in the thread pool so the event
    loop is never blocked; when batching is enabled a background flush
    task is also launched.
    """
    if self._running:
        return
    self._running = True
    # Initialize MLflow in the thread pool to avoid blocking the loop.
    # get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated here and can misbehave when no
    # loop policy is set.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(self._executor, self._init_mlflow)
    if self.enable_batching:
        self._flush_task = asyncio.create_task(self._periodic_flush())
    logger.info(
        f"InferenceMetricsTracker started for {self.service_name} "
        f"(experiment: {self.experiment_name})"
    )
def _init_mlflow(self) -> None:
"""Initialize MLflow client and experiment (runs in thread pool)."""
self._client = get_mlflow_client(
@@ -217,47 +216,47 @@ class InferenceMetricsTracker:
"type": "inference-metrics",
}
)
async def stop(self) -> None:
    """Shut the tracker down: cancel the flush task, flush leftovers.

    Safe to call when the tracker was never started (returns
    immediately). After the final flush the thread pool is drained
    synchronously.
    """
    if not self._running:
        return
    self._running = False
    pending = self._flush_task
    if pending:
        pending.cancel()
        try:
            await pending
        except asyncio.CancelledError:
            pass
    # Push any metrics still sitting in the batch before tearing down.
    await self._flush_batch()
    self._executor.shutdown(wait=True)
    logger.info(f"InferenceMetricsTracker stopped for {self.service_name}")
async def log_inference(self, metrics: InferenceMetrics) -> None:
"""
Log inference metrics (non-blocking).
Args:
metrics: InferenceMetrics object with request data
"""
if not self._running:
logger.warning("Tracker not running, skipping metrics")
return
self._request_count += 1
if metrics.has_error:
self._error_count += 1
# Update aggregates
for key, value in metrics.as_metrics_dict().items():
if value > 0:
self._aggregate_metrics[key].append(value)
if self.enable_batching:
async with self._batch_lock:
self._batch.append(metrics)
@@ -270,29 +269,29 @@ class InferenceMetricsTracker:
self._executor,
partial(self._log_single_inference, metrics)
)
async def _periodic_flush(self) -> None:
    """Background loop: flush the batch every `flush_interval` seconds.

    Exits once `_running` is cleared; a cancellation during the sleep
    propagates to the caller (handled in stop()).
    """
    while self._running:
        await asyncio.sleep(self.flush_interval)
        await self._flush_batch()
async def _flush_batch(self) -> None:
    """Swap out the current batch and upload it to MLflow off-loop.

    The batch list is exchanged under the lock so new metrics can keep
    accumulating while the (potentially slow) MLflow logging runs in
    the thread pool. No-op when the batch is empty.
    """
    async with self._batch_lock:
        if not self._batch:
            return
        batch, self._batch = self._batch, []
    # Log in the thread pool. get_running_loop() replaces the
    # deprecated-in-coroutines get_event_loop().
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(
        self._executor,
        partial(self._log_batch, batch)
    )
def _log_single_inference(self, metrics: InferenceMetrics) -> None:
"""Log a single inference request to MLflow (runs in thread pool)."""
try:
@@ -307,7 +306,7 @@ class InferenceMetricsTracker:
):
mlflow.log_params(metrics.as_params_dict())
mlflow.log_metrics(metrics.as_metrics_dict())
if metrics.user_id:
mlflow.set_tag("user_id", metrics.user_id)
if metrics.session_id:
@@ -318,18 +317,18 @@ class InferenceMetricsTracker:
mlflow.set_tag("error_message", metrics.error_message[:250])
except Exception as e:
logger.error(f"Failed to log inference metrics: {e}")
def _log_batch(self, batch: List[InferenceMetrics]) -> None:
"""Log a batch of inference metrics as aggregate statistics."""
if not batch:
return
try:
# Calculate aggregates
aggregates = self._calculate_aggregates(batch)
run_name = f"batch-{self.service_name}-{int(time.time())}"
with mlflow.start_run(
experiment_id=self._experiment_id,
run_name=run_name,
@@ -341,83 +340,83 @@ class InferenceMetricsTracker:
):
# Log aggregate metrics
mlflow.log_metrics(aggregates)
# Log batch info
mlflow.log_param("batch_size", len(batch))
mlflow.log_param("time_window_start", min(m.timestamp for m in batch))
mlflow.log_param("time_window_end", max(m.timestamp for m in batch))
# Log configuration breakdown
rag_enabled_count = sum(1 for m in batch if m.rag_enabled)
streaming_count = sum(1 for m in batch if m.is_streaming)
premium_count = sum(1 for m in batch if m.is_premium)
error_count = sum(1 for m in batch if m.has_error)
mlflow.log_metrics({
"rag_enabled_pct": rag_enabled_count / len(batch) * 100,
"streaming_pct": streaming_count / len(batch) * 100,
"premium_pct": premium_count / len(batch) * 100,
"error_rate": error_count / len(batch) * 100,
})
# Log model distribution
model_counts: Dict[str, int] = defaultdict(int)
for m in batch:
if m.model_name:
model_counts[m.model_name] += 1
if model_counts:
mlflow.log_dict(
{"models": dict(model_counts)},
"model_distribution.json"
)
logger.info(f"Logged batch of {len(batch)} inference metrics")
except Exception as e:
logger.error(f"Failed to log batch metrics: {e}")
def _calculate_aggregates(
    self,
    batch: List[InferenceMetrics]
) -> Dict[str, float]:
    """Reduce a batch of metrics to per-key summary statistics.

    For every metric key that appears with a positive value somewhere
    in the batch, emits ``<key>_mean``/``_min``/``_max``; with two or
    more samples also ``_p50`` and ``_stdev``; with four or more also
    ``_p95`` and ``_p99`` (index-based percentiles on the sorted
    values). The batch size is reported as ``total_requests``.
    """
    import statistics

    # Group positive values by metric name across the whole batch;
    # zeros are treated as "not measured" and excluded.
    grouped: Dict[str, List[float]] = defaultdict(list)
    for item in batch:
        for name, val in item.as_metrics_dict().items():
            if val > 0:
                grouped[name].append(val)

    summary: Dict[str, float] = {}
    for name, vals in grouped.items():
        if not vals:
            continue
        summary[f"{name}_mean"] = statistics.mean(vals)
        summary[f"{name}_min"] = min(vals)
        summary[f"{name}_max"] = max(vals)
        count = len(vals)
        if count >= 2:
            summary[f"{name}_p50"] = statistics.median(vals)
            summary[f"{name}_stdev"] = statistics.stdev(vals)
        if count >= 4:
            ordered = sorted(vals)
            summary[f"{name}_p95"] = ordered[int(count * 0.95)]
            summary[f"{name}_p99"] = ordered[int(count * 0.99)]

    summary["total_requests"] = float(len(batch))
    return summary
def get_stats(self) -> Dict[str, Any]:
"""Get current tracker statistics."""
return {