feat: add MLflow experiment tracking to all 4 Gradio UIs

Each UI now logs per-request metrics to MLflow:
- llm.py: latency, tokens/sec, prompt/completion tokens (gradio-llm-tuning)
- embeddings.py: latency, text length, batch size (gradio-embeddings-tuning)
- stt.py: latency, audio duration, real-time factor (gradio-stt-tuning)
- tts.py: latency, text length, audio duration (gradio-tts-tuning)

Uses try/except-guarded imports so the UIs still work even if MLflow is
unreachable. Each Gradio instance keeps one persistent run; metrics are
logged in batches via MlflowClient.log_batch().
This commit is contained in:
2026-02-13 07:54:06 -05:00
parent b2d2252342
commit 1c5dc7f751
4 changed files with 301 additions and 4 deletions

69
stt.py
View File

@@ -37,6 +37,67 @@ MLFLOW_TRACKING_URI = os.environ.get(
"http://mlflow.mlflow.svc.cluster.local:80"
)
# ─── MLflow experiment tracking ──────────────────────────────────────────
# Best-effort setup: if anything here fails (package missing, tracking
# server unreachable), tracking is disabled and the UI keeps working.
try:
    import atexit

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
    _mlflow_client = MlflowClient()
    # Reuse the experiment if it already exists; create it otherwise.
    _experiment = _mlflow_client.get_experiment_by_name("gradio-stt-tuning")
    if _experiment is None:
        _experiment_id = _mlflow_client.create_experiment(
            "gradio-stt-tuning",
            artifact_location="/mlflow/artifacts/gradio-stt-tuning",
        )
    else:
        _experiment_id = _experiment.experiment_id
    # One persistent run per Gradio instance, named after the pod/host.
    _mlflow_run = mlflow.start_run(
        experiment_id=_experiment_id,
        run_name=f"gradio-stt-{os.environ.get('HOSTNAME', 'local')}",
        tags={"service": "gradio-stt", "endpoint": STT_URL},
    )
    _mlflow_run_id = _mlflow_run.info.run_id
    _mlflow_step = 0  # monotonically increasing step for log_batch metrics
    MLFLOW_ENABLED = True
    # Mark the run FINISHED on clean interpreter shutdown; otherwise it
    # would be left in RUNNING state on the tracking server forever.
    atexit.register(mlflow.end_run)
    logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
except Exception as exc:
    logger.warning("MLflow tracking disabled: %s", exc)
    _mlflow_client = None
    _mlflow_run_id = None
    _mlflow_step = 0
    MLFLOW_ENABLED = False
def _log_stt_metrics(
    latency: float, audio_duration: float, word_count: int, task: str,
) -> None:
    """Log STT inference metrics to MLflow (non-blocking best-effort)."""
    global _mlflow_step
    if not MLFLOW_ENABLED or _mlflow_client is None:
        return
    try:
        _mlflow_step += 1
        step = _mlflow_step
        now_ms = int(time.time() * 1000)
        # Real-time factor: processing time per second of audio (0 when
        # the duration is unknown/zero, to avoid division by zero).
        rtf = 0 if audio_duration <= 0 else latency / audio_duration
        values = {
            "latency_s": latency,
            "audio_duration_s": audio_duration,
            "realtime_factor": rtf,
            "word_count": word_count,
        }
        metrics = [
            mlflow.entities.Metric(name, value, now_ms, step)
            for name, value in values.items()
        ]
        # Params are write-once in MLflow, so only send them on the first call.
        params = [mlflow.entities.Param("task", task)] if step == 1 else []
        _mlflow_client.log_batch(_mlflow_run_id, metrics=metrics, params=params)
    except Exception:
        logger.debug("MLflow log failed", exc_info=True)
# Shared HTTP client for the STT backend; transcription of long clips is
# slow, so use a generous 180 s timeout (httpx's default is only 5 s).
client = httpx.Client(timeout=180.0)
@@ -116,6 +177,14 @@ def transcribe_audio(
text = result.get("text", "")
detected_language = result.get("language", "unknown")
# Log to MLflow
_log_stt_metrics(
latency=latency,
audio_duration=audio_duration,
word_count=len(text.split()),
task=task,
)
# Status message
status = f"✅ Transcribed {audio_duration:.1f}s of audio in {latency*1000:.0f}ms"