feat: add MLflow experiment tracking to all 4 Gradio UIs
Each UI now logs per-request metrics to MLflow:

- llm.py: latency, tokens/sec, prompt/completion tokens (gradio-llm-tuning)
- embeddings.py: latency, text length, batch size (gradio-embeddings-tuning)
- stt.py: latency, audio duration, real-time factor (gradio-stt-tuning)
- tts.py: latency, text length, audio duration (gradio-tts-tuning)

Uses try/except guarded imports so the UIs still work if MLflow is unreachable. One persistent run per Gradio instance; metrics are batched via MlflowClient.log_batch().
This commit is contained in:
97
llm.py
97
llm.py
@@ -30,6 +30,83 @@ LLM_URL = os.environ.get(
|
||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/llm",
|
||||
)
|
||||
|
||||
# ─── MLflow experiment tracking ──────────────────────────────────────────
# Best-effort setup: any failure (mlflow not installed, tracking server
# unreachable) downgrades to MLFLOW_ENABLED = False and the UI keeps working.
try:
    import mlflow
    from mlflow.tracking import MlflowClient

    MLFLOW_TRACKING_URI = os.environ.get(
        "MLFLOW_TRACKING_URI",
        "http://mlflow.mlflow.svc.cluster.local:80",
    )
    mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
    _mlflow_client = MlflowClient()

    # Ensure the experiment exists. Multiple replicas can race between the
    # lookup and the create; if create_experiment fails because another
    # instance won the race, re-fetch instead of disabling tracking here.
    _EXPERIMENT_NAME = "gradio-llm-tuning"
    _experiment = _mlflow_client.get_experiment_by_name(_EXPERIMENT_NAME)
    if _experiment is None:
        try:
            _experiment_id = _mlflow_client.create_experiment(
                _EXPERIMENT_NAME,
                artifact_location="/mlflow/artifacts/gradio-llm-tuning",
            )
        except Exception:
            # Lost the creation race — the experiment should exist now.
            # If this second lookup also fails, the outer except disables
            # tracking as before.
            _experiment = _mlflow_client.get_experiment_by_name(_EXPERIMENT_NAME)
            _experiment_id = _experiment.experiment_id
    else:
        _experiment_id = _experiment.experiment_id

    # One persistent run per Gradio instance; per-request metrics are
    # appended to it as increasing steps (see _log_llm_metrics below).
    _run_name = f"gradio-llm-{os.environ.get('HOSTNAME', 'local')}"
    _mlflow_run = mlflow.start_run(
        experiment_id=_experiment_id,
        run_name=_run_name,
        tags={
            "service": "gradio-llm",
            "endpoint": LLM_URL,
            "mlflow.runName": _run_name,
        },
    )
    _mlflow_run_id = _mlflow_run.info.run_id
    _mlflow_step = 0
    MLFLOW_ENABLED = True
    logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
except Exception as exc:
    logger.warning("MLflow tracking disabled: %s", exc)
    _mlflow_client = None
    _mlflow_run_id = None
    _mlflow_step = 0
    MLFLOW_ENABLED = False
||||
def _log_llm_metrics(
    latency: float,
    prompt_tokens: int,
    completion_tokens: int,
    temperature: float,
    max_tokens: int,
    top_p: float,
) -> None:
    """Log one request's inference metrics to MLflow (best-effort).

    Appends a batch of metrics to the instance-wide persistent run at the
    next step index. A no-op when tracking is disabled; any logging failure
    is swallowed (debug-logged) so it never disturbs the request path.
    """
    global _mlflow_step
    if not MLFLOW_ENABLED or _mlflow_client is None:
        return
    try:
        _mlflow_step += 1
        step = _mlflow_step
        now_ms = int(time.time() * 1000)
        # Guard against a zero/negative latency so throughput never divides by zero.
        throughput = completion_tokens / latency if latency > 0 else 0
        values = {
            "latency_s": latency,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
            "tokens_per_second": throughput,
            "temperature": temperature,
            "max_tokens_requested": max_tokens,
            "top_p": top_p,
        }
        _mlflow_client.log_batch(
            _mlflow_run_id,
            metrics=[
                mlflow.entities.Metric(key, value, now_ms, step)
                for key, value in values.items()
            ],
        )
    except Exception:
        logger.debug("MLflow log failed", exc_info=True)
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"You are a helpful AI assistant running on Davies Tech Labs homelab infrastructure. "
|
||||
"You are powered by Llama 3.1 70B served via vLLM on AMD Strix Halo (ROCm). "
|
||||
@@ -90,6 +167,16 @@ async def chat_stream(
|
||||
usage.get("completion_tokens", 0),
|
||||
)
|
||||
|
||||
# Log to MLflow
|
||||
_log_llm_metrics(
|
||||
latency=latency,
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
# Yield text progressively for a nicer streaming feel
|
||||
chunk_size = 4
|
||||
words = text.split(" ")
|
||||
@@ -164,6 +251,16 @@ def single_prompt(
|
||||
text = result["choices"][0]["message"]["content"]
|
||||
usage = result.get("usage", {})
|
||||
|
||||
# Log to MLflow
|
||||
_log_llm_metrics(
|
||||
latency=latency,
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
metrics = f"""
|
||||
**Generation Metrics:**
|
||||
- Latency: {latency:.1f}s
|
||||
|
||||
Reference in New Issue
Block a user