feat: add MLflow experiment tracking to all 4 Gradio UIs

Each UI now logs per-request metrics to MLflow:
- llm.py: latency, tokens/sec, prompt/completion tokens (gradio-llm-tuning)
- embeddings.py: latency, text length, batch size (gradio-embeddings-tuning)
- stt.py: latency, audio duration, real-time factor (gradio-stt-tuning)
- tts.py: latency, text length, audio duration (gradio-tts-tuning)

Uses try/except guarded imports so UIs still work if MLflow is
unreachable. Persistent run per Gradio instance, batched metric logging
via MlflowClient.log_batch().
This commit is contained in:
2026-02-13 07:54:06 -05:00
parent b2d2252342
commit 1c5dc7f751
4 changed files with 301 additions and 4 deletions

View File

@@ -30,10 +30,64 @@ EMBEDDINGS_URL = os.environ.get(
# Default: Ray Serve Embeddings endpoint # Default: Ray Serve Embeddings endpoint
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings" "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings"
) )
MLFLOW_TRACKING_URI = os.environ.get( # ─── MLflow experiment tracking ──────────────────────────────────────────
try:
import mlflow
from mlflow.tracking import MlflowClient
MLFLOW_TRACKING_URI = os.environ.get(
"MLFLOW_TRACKING_URI", "MLFLOW_TRACKING_URI",
"http://mlflow.mlflow.svc.cluster.local:80" "http://mlflow.mlflow.svc.cluster.local:80",
) )
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
_mlflow_client = MlflowClient()
_experiment = _mlflow_client.get_experiment_by_name("gradio-embeddings-tuning")
if _experiment is None:
_experiment_id = _mlflow_client.create_experiment(
"gradio-embeddings-tuning",
artifact_location="/mlflow/artifacts/gradio-embeddings-tuning",
)
else:
_experiment_id = _experiment.experiment_id
_mlflow_run = mlflow.start_run(
experiment_id=_experiment_id,
run_name=f"gradio-embeddings-{os.environ.get('HOSTNAME', 'local')}",
tags={"service": "gradio-embeddings", "endpoint": EMBEDDINGS_URL},
)
_mlflow_run_id = _mlflow_run.info.run_id
_mlflow_step = 0
MLFLOW_ENABLED = True
logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
except Exception as exc:
logger.warning("MLflow tracking disabled: %s", exc)
_mlflow_client = None
_mlflow_run_id = None
_mlflow_step = 0
MLFLOW_ENABLED = False
def _log_embedding_metrics(latency: float, batch_size: int, embedding_dims: int = 0) -> None:
"""Log embedding inference metrics to MLflow (non-blocking best-effort)."""
global _mlflow_step
if not MLFLOW_ENABLED or _mlflow_client is None:
return
try:
_mlflow_step += 1
ts = int(time.time() * 1000)
_mlflow_client.log_batch(
_mlflow_run_id,
metrics=[
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
mlflow.entities.Metric("batch_size", batch_size, ts, _mlflow_step),
mlflow.entities.Metric("embedding_dims", embedding_dims, ts, _mlflow_step),
mlflow.entities.Metric("latency_per_text_ms", (latency * 1000 / batch_size) if batch_size > 0 else 0, ts, _mlflow_step),
],
)
except Exception:
logger.debug("MLflow log failed", exc_info=True)
# HTTP client # HTTP client
client = httpx.Client(timeout=60.0) client = httpx.Client(timeout=60.0)
@@ -77,6 +131,9 @@ def generate_single_embedding(text: str) -> tuple[str, str, str]:
embedding = embeddings[0] embedding = embeddings[0]
dims = len(embedding) dims = len(embedding)
# Log to MLflow
_log_embedding_metrics(latency, batch_size=1, embedding_dims=dims)
# Format output # Format output
status = f"✅ Generated {dims}-dimensional embedding in {latency*1000:.1f}ms" status = f"✅ Generated {dims}-dimensional embedding in {latency*1000:.1f}ms"
@@ -119,6 +176,9 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]:
similarity = cosine_similarity(embeddings[0], embeddings[1]) similarity = cosine_similarity(embeddings[0], embeddings[1])
# Log to MLflow
_log_embedding_metrics(latency, batch_size=2, embedding_dims=len(embeddings[0]))
# Determine similarity level # Determine similarity level
if similarity > 0.9: if similarity > 0.9:
level = "🟢 Very High" level = "🟢 Very High"
@@ -167,6 +227,9 @@ def batch_embed(texts_input: str) -> tuple[str, str]:
try: try:
embeddings, latency = get_embeddings(texts) embeddings, latency = get_embeddings(texts)
# Log to MLflow
_log_embedding_metrics(latency, batch_size=len(embeddings), embedding_dims=len(embeddings[0]) if embeddings else 0)
status = f"✅ Generated {len(embeddings)} embeddings in {latency*1000:.1f}ms" status = f"✅ Generated {len(embeddings)} embeddings in {latency*1000:.1f}ms"
status += f" ({latency*1000/len(texts):.1f}ms per text)" status += f" ({latency*1000/len(texts):.1f}ms per text)"

97
llm.py
View File

@@ -30,6 +30,83 @@ LLM_URL = os.environ.get(
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/llm", "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/llm",
) )
# ─── MLflow experiment tracking ──────────────────────────────────────────
try:
import mlflow
from mlflow.tracking import MlflowClient
MLFLOW_TRACKING_URI = os.environ.get(
"MLFLOW_TRACKING_URI",
"http://mlflow.mlflow.svc.cluster.local:80",
)
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
_mlflow_client = MlflowClient()
# Ensure experiment exists
_experiment = _mlflow_client.get_experiment_by_name("gradio-llm-tuning")
if _experiment is None:
_experiment_id = _mlflow_client.create_experiment(
"gradio-llm-tuning",
artifact_location="/mlflow/artifacts/gradio-llm-tuning",
)
else:
_experiment_id = _experiment.experiment_id
# One persistent run per Gradio instance
_mlflow_run = mlflow.start_run(
experiment_id=_experiment_id,
run_name=f"gradio-llm-{os.environ.get('HOSTNAME', 'local')}",
tags={
"service": "gradio-llm",
"endpoint": LLM_URL,
"mlflow.runName": f"gradio-llm-{os.environ.get('HOSTNAME', 'local')}",
},
)
_mlflow_run_id = _mlflow_run.info.run_id
_mlflow_step = 0
MLFLOW_ENABLED = True
logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
except Exception as exc:
logger.warning("MLflow tracking disabled: %s", exc)
_mlflow_client = None
_mlflow_run_id = None
_mlflow_step = 0
MLFLOW_ENABLED = False
def _log_llm_metrics(
latency: float,
prompt_tokens: int,
completion_tokens: int,
temperature: float,
max_tokens: int,
top_p: float,
) -> None:
"""Log inference metrics to MLflow (non-blocking best-effort)."""
global _mlflow_step
if not MLFLOW_ENABLED or _mlflow_client is None:
return
try:
_mlflow_step += 1
ts = int(time.time() * 1000)
total_tokens = prompt_tokens + completion_tokens
tps = completion_tokens / latency if latency > 0 else 0
_mlflow_client.log_batch(
_mlflow_run_id,
metrics=[
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
mlflow.entities.Metric("prompt_tokens", prompt_tokens, ts, _mlflow_step),
mlflow.entities.Metric("completion_tokens", completion_tokens, ts, _mlflow_step),
mlflow.entities.Metric("total_tokens", total_tokens, ts, _mlflow_step),
mlflow.entities.Metric("tokens_per_second", tps, ts, _mlflow_step),
mlflow.entities.Metric("temperature", temperature, ts, _mlflow_step),
mlflow.entities.Metric("max_tokens_requested", max_tokens, ts, _mlflow_step),
mlflow.entities.Metric("top_p", top_p, ts, _mlflow_step),
],
)
except Exception:
logger.debug("MLflow log failed", exc_info=True)
DEFAULT_SYSTEM_PROMPT = ( DEFAULT_SYSTEM_PROMPT = (
"You are a helpful AI assistant running on Davies Tech Labs homelab infrastructure. " "You are a helpful AI assistant running on Davies Tech Labs homelab infrastructure. "
"You are powered by Llama 3.1 70B served via vLLM on AMD Strix Halo (ROCm). " "You are powered by Llama 3.1 70B served via vLLM on AMD Strix Halo (ROCm). "
@@ -90,6 +167,16 @@ async def chat_stream(
usage.get("completion_tokens", 0), usage.get("completion_tokens", 0),
) )
# Log to MLflow
_log_llm_metrics(
latency=latency,
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)
# Yield text progressively for a nicer streaming feel # Yield text progressively for a nicer streaming feel
chunk_size = 4 chunk_size = 4
words = text.split(" ") words = text.split(" ")
@@ -164,6 +251,16 @@ def single_prompt(
text = result["choices"][0]["message"]["content"] text = result["choices"][0]["message"]["content"]
usage = result.get("usage", {}) usage = result.get("usage", {})
# Log to MLflow
_log_llm_metrics(
latency=latency,
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)
metrics = f""" metrics = f"""
**Generation Metrics:** **Generation Metrics:**
- Latency: {latency:.1f}s - Latency: {latency:.1f}s

69
stt.py
View File

@@ -37,6 +37,67 @@ MLFLOW_TRACKING_URI = os.environ.get(
"http://mlflow.mlflow.svc.cluster.local:80" "http://mlflow.mlflow.svc.cluster.local:80"
) )
# ─── MLflow experiment tracking ──────────────────────────────────────────
try:
import mlflow
from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
_mlflow_client = MlflowClient()
_experiment = _mlflow_client.get_experiment_by_name("gradio-stt-tuning")
if _experiment is None:
_experiment_id = _mlflow_client.create_experiment(
"gradio-stt-tuning",
artifact_location="/mlflow/artifacts/gradio-stt-tuning",
)
else:
_experiment_id = _experiment.experiment_id
_mlflow_run = mlflow.start_run(
experiment_id=_experiment_id,
run_name=f"gradio-stt-{os.environ.get('HOSTNAME', 'local')}",
tags={"service": "gradio-stt", "endpoint": STT_URL},
)
_mlflow_run_id = _mlflow_run.info.run_id
_mlflow_step = 0
MLFLOW_ENABLED = True
logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
except Exception as exc:
logger.warning("MLflow tracking disabled: %s", exc)
_mlflow_client = None
_mlflow_run_id = None
_mlflow_step = 0
MLFLOW_ENABLED = False
def _log_stt_metrics(
latency: float, audio_duration: float, word_count: int, task: str,
) -> None:
"""Log STT inference metrics to MLflow (non-blocking best-effort)."""
global _mlflow_step
if not MLFLOW_ENABLED or _mlflow_client is None:
return
try:
_mlflow_step += 1
ts = int(time.time() * 1000)
rtf = latency / audio_duration if audio_duration > 0 else 0
_mlflow_client.log_batch(
_mlflow_run_id,
metrics=[
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
mlflow.entities.Metric("audio_duration_s", audio_duration, ts, _mlflow_step),
mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
mlflow.entities.Metric("word_count", word_count, ts, _mlflow_step),
],
params=[] if _mlflow_step > 1 else [
mlflow.entities.Param("task", task),
],
)
except Exception:
logger.debug("MLflow log failed", exc_info=True)
# HTTP client with longer timeout for transcription # HTTP client with longer timeout for transcription
client = httpx.Client(timeout=180.0) client = httpx.Client(timeout=180.0)
@@ -117,6 +178,14 @@ def transcribe_audio(
text = result.get("text", "") text = result.get("text", "")
detected_language = result.get("language", "unknown") detected_language = result.get("language", "unknown")
# Log to MLflow
_log_stt_metrics(
latency=latency,
audio_duration=audio_duration,
word_count=len(text.split()),
task=task,
)
# Status message # Status message
status = f"✅ Transcribed {audio_duration:.1f}s of audio in {latency*1000:.0f}ms" status = f"✅ Transcribed {audio_duration:.1f}s of audio in {latency*1000:.0f}ms"

68
tts.py
View File

@@ -37,6 +37,66 @@ MLFLOW_TRACKING_URI = os.environ.get(
"http://mlflow.mlflow.svc.cluster.local:80" "http://mlflow.mlflow.svc.cluster.local:80"
) )
# ─── MLflow experiment tracking ──────────────────────────────────────────
try:
import mlflow
from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
_mlflow_client = MlflowClient()
_experiment = _mlflow_client.get_experiment_by_name("gradio-tts-tuning")
if _experiment is None:
_experiment_id = _mlflow_client.create_experiment(
"gradio-tts-tuning",
artifact_location="/mlflow/artifacts/gradio-tts-tuning",
)
else:
_experiment_id = _experiment.experiment_id
_mlflow_run = mlflow.start_run(
experiment_id=_experiment_id,
run_name=f"gradio-tts-{os.environ.get('HOSTNAME', 'local')}",
tags={"service": "gradio-tts", "endpoint": TTS_URL},
)
_mlflow_run_id = _mlflow_run.info.run_id
_mlflow_step = 0
MLFLOW_ENABLED = True
logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
except Exception as exc:
logger.warning("MLflow tracking disabled: %s", exc)
_mlflow_client = None
_mlflow_run_id = None
_mlflow_step = 0
MLFLOW_ENABLED = False
def _log_tts_metrics(
latency: float, audio_duration: float, text_chars: int, language: str,
) -> None:
"""Log TTS inference metrics to MLflow (non-blocking best-effort)."""
global _mlflow_step
if not MLFLOW_ENABLED or _mlflow_client is None:
return
try:
_mlflow_step += 1
ts = int(time.time() * 1000)
rtf = latency / audio_duration if audio_duration > 0 else 0
cps = text_chars / latency if latency > 0 else 0
_mlflow_client.log_batch(
_mlflow_run_id,
metrics=[
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
mlflow.entities.Metric("audio_duration_s", audio_duration, ts, _mlflow_step),
mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
mlflow.entities.Metric("chars_per_second", cps, ts, _mlflow_step),
mlflow.entities.Metric("text_chars", text_chars, ts, _mlflow_step),
],
)
except Exception:
logger.debug("MLflow log failed", exc_info=True)
# HTTP client with longer timeout for audio generation # HTTP client with longer timeout for audio generation
client = httpx.Client(timeout=120.0) client = httpx.Client(timeout=120.0)
@@ -94,6 +154,14 @@ def synthesize_speech(text: str, language: str) -> tuple[str, tuple[int, np.ndar
# Status message # Status message
status = f"✅ Generated {duration:.2f}s of audio in {latency*1000:.0f}ms" status = f"✅ Generated {duration:.2f}s of audio in {latency*1000:.0f}ms"
# Log to MLflow
_log_tts_metrics(
latency=latency,
audio_duration=duration,
text_chars=len(text),
language=lang_code,
)
# Metrics # Metrics
metrics = f""" metrics = f"""
**Audio Statistics:** **Audio Statistics:**