style: apply ruff format to all files
This commit is contained in:
@@ -29,6 +29,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class InferenceMetrics:
|
||||
"""Metrics collected during an inference request."""
|
||||
|
||||
request_id: str
|
||||
user_id: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
@@ -190,31 +191,22 @@ class InferenceMetricsTracker:
|
||||
|
||||
# Initialize MLflow in thread pool to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
self._executor,
|
||||
self._init_mlflow
|
||||
)
|
||||
await loop.run_in_executor(self._executor, self._init_mlflow)
|
||||
|
||||
if self.enable_batching:
|
||||
self._flush_task = asyncio.create_task(self._periodic_flush())
|
||||
|
||||
logger.info(
|
||||
f"InferenceMetricsTracker started for {self.service_name} "
|
||||
f"(experiment: {self.experiment_name})"
|
||||
)
|
||||
logger.info(f"InferenceMetricsTracker started for {self.service_name} (experiment: {self.experiment_name})")
|
||||
|
||||
def _init_mlflow(self) -> None:
|
||||
"""Initialize MLflow client and experiment (runs in thread pool)."""
|
||||
self._client = get_mlflow_client(
|
||||
tracking_uri=self.tracking_uri,
|
||||
configure_global=True
|
||||
)
|
||||
self._client = get_mlflow_client(tracking_uri=self.tracking_uri, configure_global=True)
|
||||
self._experiment_id = ensure_experiment(
|
||||
self.experiment_name,
|
||||
tags={
|
||||
"service": self.service_name,
|
||||
"type": "inference-metrics",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
@@ -265,10 +257,7 @@ class InferenceMetricsTracker:
|
||||
else:
|
||||
# Immediate logging in thread pool
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
self._executor,
|
||||
partial(self._log_single_inference, metrics)
|
||||
)
|
||||
await loop.run_in_executor(self._executor, partial(self._log_single_inference, metrics))
|
||||
|
||||
async def _periodic_flush(self) -> None:
|
||||
"""Periodically flush batched metrics."""
|
||||
@@ -287,10 +276,7 @@ class InferenceMetricsTracker:
|
||||
|
||||
# Log in thread pool
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
self._executor,
|
||||
partial(self._log_batch, batch)
|
||||
)
|
||||
await loop.run_in_executor(self._executor, partial(self._log_batch, batch))
|
||||
|
||||
def _log_single_inference(self, metrics: InferenceMetrics) -> None:
|
||||
"""Log a single inference request to MLflow (runs in thread pool)."""
|
||||
@@ -302,7 +288,7 @@ class InferenceMetricsTracker:
|
||||
"service": self.service_name,
|
||||
"request_id": metrics.request_id,
|
||||
"type": "single-inference",
|
||||
}
|
||||
},
|
||||
):
|
||||
mlflow.log_params(metrics.as_params_dict())
|
||||
mlflow.log_metrics(metrics.as_metrics_dict())
|
||||
@@ -336,7 +322,7 @@ class InferenceMetricsTracker:
|
||||
"service": self.service_name,
|
||||
"type": "batch-inference",
|
||||
"batch_size": str(len(batch)),
|
||||
}
|
||||
},
|
||||
):
|
||||
# Log aggregate metrics
|
||||
mlflow.log_metrics(aggregates)
|
||||
@@ -352,12 +338,14 @@ class InferenceMetricsTracker:
|
||||
premium_count = sum(1 for m in batch if m.is_premium)
|
||||
error_count = sum(1 for m in batch if m.has_error)
|
||||
|
||||
mlflow.log_metrics({
|
||||
"rag_enabled_pct": rag_enabled_count / len(batch) * 100,
|
||||
"streaming_pct": streaming_count / len(batch) * 100,
|
||||
"premium_pct": premium_count / len(batch) * 100,
|
||||
"error_rate": error_count / len(batch) * 100,
|
||||
})
|
||||
mlflow.log_metrics(
|
||||
{
|
||||
"rag_enabled_pct": rag_enabled_count / len(batch) * 100,
|
||||
"streaming_pct": streaming_count / len(batch) * 100,
|
||||
"premium_pct": premium_count / len(batch) * 100,
|
||||
"error_rate": error_count / len(batch) * 100,
|
||||
}
|
||||
)
|
||||
|
||||
# Log model distribution
|
||||
model_counts: Dict[str, int] = defaultdict(int)
|
||||
@@ -366,20 +354,14 @@ class InferenceMetricsTracker:
|
||||
model_counts[m.model_name] += 1
|
||||
|
||||
if model_counts:
|
||||
mlflow.log_dict(
|
||||
{"models": dict(model_counts)},
|
||||
"model_distribution.json"
|
||||
)
|
||||
mlflow.log_dict({"models": dict(model_counts)}, "model_distribution.json")
|
||||
|
||||
logger.info(f"Logged batch of {len(batch)} inference metrics")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log batch metrics: {e}")
|
||||
|
||||
def _calculate_aggregates(
|
||||
self,
|
||||
batch: List[InferenceMetrics]
|
||||
) -> Dict[str, float]:
|
||||
def _calculate_aggregates(self, batch: List[InferenceMetrics]) -> Dict[str, float]:
|
||||
"""Calculate aggregate statistics from a batch of metrics."""
|
||||
import statistics
|
||||
|
||||
|
||||
Reference in New Issue
Block a user