style: apply ruff format to all files
Some checks failed
CI / Lint (push) Successful in 1m46s
CI / Test (push) Successful in 1m44s
CI / Publish (push) Failing after 19s
CI / Notify (push) Successful in 1s

This commit is contained in:
2026-02-13 11:05:26 -05:00
parent 1c841729a0
commit ca5bef9664
7 changed files with 89 additions and 222 deletions

View File

@@ -29,6 +29,7 @@ logger = logging.getLogger(__name__)
@dataclass
class InferenceMetrics:
"""Metrics collected during an inference request."""
request_id: str
user_id: Optional[str] = None
session_id: Optional[str] = None
@@ -190,31 +191,22 @@ class InferenceMetricsTracker:
# Initialize MLflow in thread pool to avoid blocking
loop = asyncio.get_event_loop()
await loop.run_in_executor(
self._executor,
self._init_mlflow
)
await loop.run_in_executor(self._executor, self._init_mlflow)
if self.enable_batching:
self._flush_task = asyncio.create_task(self._periodic_flush())
logger.info(
f"InferenceMetricsTracker started for {self.service_name} "
f"(experiment: {self.experiment_name})"
)
logger.info(f"InferenceMetricsTracker started for {self.service_name} (experiment: {self.experiment_name})")
def _init_mlflow(self) -> None:
"""Initialize MLflow client and experiment (runs in thread pool)."""
self._client = get_mlflow_client(
tracking_uri=self.tracking_uri,
configure_global=True
)
self._client = get_mlflow_client(tracking_uri=self.tracking_uri, configure_global=True)
self._experiment_id = ensure_experiment(
self.experiment_name,
tags={
"service": self.service_name,
"type": "inference-metrics",
}
},
)
async def stop(self) -> None:
@@ -265,10 +257,7 @@ class InferenceMetricsTracker:
else:
# Immediate logging in thread pool
loop = asyncio.get_event_loop()
await loop.run_in_executor(
self._executor,
partial(self._log_single_inference, metrics)
)
await loop.run_in_executor(self._executor, partial(self._log_single_inference, metrics))
async def _periodic_flush(self) -> None:
"""Periodically flush batched metrics."""
@@ -287,10 +276,7 @@ class InferenceMetricsTracker:
# Log in thread pool
loop = asyncio.get_event_loop()
await loop.run_in_executor(
self._executor,
partial(self._log_batch, batch)
)
await loop.run_in_executor(self._executor, partial(self._log_batch, batch))
def _log_single_inference(self, metrics: InferenceMetrics) -> None:
"""Log a single inference request to MLflow (runs in thread pool)."""
@@ -302,7 +288,7 @@ class InferenceMetricsTracker:
"service": self.service_name,
"request_id": metrics.request_id,
"type": "single-inference",
}
},
):
mlflow.log_params(metrics.as_params_dict())
mlflow.log_metrics(metrics.as_metrics_dict())
@@ -336,7 +322,7 @@ class InferenceMetricsTracker:
"service": self.service_name,
"type": "batch-inference",
"batch_size": str(len(batch)),
}
},
):
# Log aggregate metrics
mlflow.log_metrics(aggregates)
@@ -352,12 +338,14 @@ class InferenceMetricsTracker:
premium_count = sum(1 for m in batch if m.is_premium)
error_count = sum(1 for m in batch if m.has_error)
mlflow.log_metrics({
"rag_enabled_pct": rag_enabled_count / len(batch) * 100,
"streaming_pct": streaming_count / len(batch) * 100,
"premium_pct": premium_count / len(batch) * 100,
"error_rate": error_count / len(batch) * 100,
})
mlflow.log_metrics(
{
"rag_enabled_pct": rag_enabled_count / len(batch) * 100,
"streaming_pct": streaming_count / len(batch) * 100,
"premium_pct": premium_count / len(batch) * 100,
"error_rate": error_count / len(batch) * 100,
}
)
# Log model distribution
model_counts: Dict[str, int] = defaultdict(int)
@@ -366,20 +354,14 @@ class InferenceMetricsTracker:
model_counts[m.model_name] += 1
if model_counts:
mlflow.log_dict(
{"models": dict(model_counts)},
"model_distribution.json"
)
mlflow.log_dict({"models": dict(model_counts)}, "model_distribution.json")
logger.info(f"Logged batch of {len(batch)} inference metrics")
except Exception as e:
logger.error(f"Failed to log batch metrics: {e}")
def _calculate_aggregates(
self,
batch: List[InferenceMetrics]
) -> Dict[str, float]:
def _calculate_aggregates(self, batch: List[InferenceMetrics]) -> Dict[str, float]:
"""Calculate aggregate statistics from a batch of metrics."""
import statistics