diff --git a/embeddings.py b/embeddings.py index 6adc0ec..cae7867 100644 --- a/embeddings.py +++ b/embeddings.py @@ -30,10 +30,64 @@ EMBEDDINGS_URL = os.environ.get( # Default: Ray Serve Embeddings endpoint "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings" ) -MLFLOW_TRACKING_URI = os.environ.get( - "MLFLOW_TRACKING_URI", - "http://mlflow.mlflow.svc.cluster.local:80" -) +# ─── MLflow experiment tracking ────────────────────────────────────────── +try: + import mlflow + from mlflow.tracking import MlflowClient + + MLFLOW_TRACKING_URI = os.environ.get( + "MLFLOW_TRACKING_URI", + "http://mlflow.mlflow.svc.cluster.local:80", + ) + mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) + _mlflow_client = MlflowClient() + + _experiment = _mlflow_client.get_experiment_by_name("gradio-embeddings-tuning") + if _experiment is None: + _experiment_id = _mlflow_client.create_experiment( + "gradio-embeddings-tuning", + artifact_location="/mlflow/artifacts/gradio-embeddings-tuning", + ) + else: + _experiment_id = _experiment.experiment_id + + _mlflow_run = mlflow.start_run( + experiment_id=_experiment_id, + run_name=f"gradio-embeddings-{os.environ.get('HOSTNAME', 'local')}", + tags={"service": "gradio-embeddings", "endpoint": EMBEDDINGS_URL}, + ) + _mlflow_run_id = _mlflow_run.info.run_id + _mlflow_step = 0 + MLFLOW_ENABLED = True + logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id) +except Exception as exc: + logger.warning("MLflow tracking disabled: %s", exc) + _mlflow_client = None + _mlflow_run_id = None + _mlflow_step = 0 + MLFLOW_ENABLED = False + + +def _log_embedding_metrics(latency: float, batch_size: int, embedding_dims: int = 0) -> None: + """Log embedding inference metrics to MLflow (non-blocking best-effort).""" + global _mlflow_step + if not MLFLOW_ENABLED or _mlflow_client is None: + return + try: + _mlflow_step += 1 + ts = int(time.time() * 1000) + _mlflow_client.log_batch( + _mlflow_run_id, + metrics=[ + mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step), + mlflow.entities.Metric("batch_size", batch_size, ts, _mlflow_step), + mlflow.entities.Metric("embedding_dims", embedding_dims, ts, _mlflow_step), + mlflow.entities.Metric("latency_per_text_ms", (latency * 1000 / batch_size) if batch_size > 0 else 0, ts, _mlflow_step), + ], + ) + except Exception: + logger.debug("MLflow log failed", exc_info=True) + # HTTP client client = httpx.Client(timeout=60.0) @@ -76,6 +130,9 @@ def generate_single_embedding(text: str) -> tuple[str, str, str]: embedding = embeddings[0] dims = len(embedding) + + # Log to MLflow + _log_embedding_metrics(latency, batch_size=1, embedding_dims=dims) # Format output status = f"✅ Generated {dims}-dimensional embedding in {latency*1000:.1f}ms" @@ -118,6 +175,9 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]: return "❌ Failed to get embeddings for both texts", "" similarity = cosine_similarity(embeddings[0], embeddings[1]) + + # Log to MLflow + _log_embedding_metrics(latency, batch_size=2, embedding_dims=len(embeddings[0])) # Determine similarity level if similarity > 0.9: @@ -167,6 +227,9 @@ def batch_embed(texts_input: str) -> tuple[str, str]: try: embeddings, latency = get_embeddings(texts) + # Log to MLflow + _log_embedding_metrics(latency, batch_size=len(embeddings), embedding_dims=len(embeddings[0]) if embeddings else 0) + status = f"✅ Generated {len(embeddings)} embeddings in {latency*1000:.1f}ms" status += f" ({latency*1000/len(texts):.1f}ms per text)" diff --git a/llm.py b/llm.py index 0fba472..d49b300 100644 --- a/llm.py +++ b/llm.py @@ -30,6 +30,83 @@ LLM_URL = os.environ.get( "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/llm", ) +# ─── MLflow experiment tracking ────────────────────────────────────────── +try: + import mlflow + from mlflow.tracking import MlflowClient + + MLFLOW_TRACKING_URI = os.environ.get( + "MLFLOW_TRACKING_URI", + "http://mlflow.mlflow.svc.cluster.local:80", + ) + mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) + _mlflow_client = MlflowClient() + + # Ensure experiment exists + _experiment = _mlflow_client.get_experiment_by_name("gradio-llm-tuning") + if _experiment is None: + _experiment_id = _mlflow_client.create_experiment( + "gradio-llm-tuning", + artifact_location="/mlflow/artifacts/gradio-llm-tuning", + ) + else: + _experiment_id = _experiment.experiment_id + + # One persistent run per Gradio instance + _mlflow_run = mlflow.start_run( + experiment_id=_experiment_id, + run_name=f"gradio-llm-{os.environ.get('HOSTNAME', 'local')}", + tags={ + "service": "gradio-llm", + "endpoint": LLM_URL, + "mlflow.runName": f"gradio-llm-{os.environ.get('HOSTNAME', 'local')}", + }, + ) + _mlflow_run_id = _mlflow_run.info.run_id + _mlflow_step = 0 + MLFLOW_ENABLED = True + logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id) +except Exception as exc: + logger.warning("MLflow tracking disabled: %s", exc) + _mlflow_client = None + _mlflow_run_id = None + _mlflow_step = 0 + MLFLOW_ENABLED = False + + +def _log_llm_metrics( + latency: float, + prompt_tokens: int, + completion_tokens: int, + temperature: float, + max_tokens: int, + top_p: float, +) -> None: + """Log inference metrics to MLflow (non-blocking best-effort).""" + global _mlflow_step + if not MLFLOW_ENABLED or _mlflow_client is None: + return + try: + _mlflow_step += 1 + ts = int(time.time() * 1000) + total_tokens = prompt_tokens + completion_tokens + tps = completion_tokens / latency if latency > 0 else 0 + _mlflow_client.log_batch( + _mlflow_run_id, + metrics=[ + mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step), + mlflow.entities.Metric("prompt_tokens", prompt_tokens, ts, _mlflow_step), + mlflow.entities.Metric("completion_tokens", completion_tokens, ts, _mlflow_step), + mlflow.entities.Metric("total_tokens", total_tokens, ts, _mlflow_step), + mlflow.entities.Metric("tokens_per_second", tps, ts, _mlflow_step), + mlflow.entities.Metric("temperature", temperature, ts, _mlflow_step), + mlflow.entities.Metric("max_tokens_requested", max_tokens, ts, _mlflow_step), + mlflow.entities.Metric("top_p", top_p, ts, _mlflow_step), + ], + ) + except Exception: + logger.debug("MLflow log failed", exc_info=True) + DEFAULT_SYSTEM_PROMPT = ( "You are a helpful AI assistant running on Davies Tech Labs homelab infrastructure. " "You are powered by Llama 3.1 70B served via vLLM on AMD Strix Halo (ROCm). " @@ -90,6 +167,16 @@ async def chat_stream( usage.get("completion_tokens", 0), ) + # Log to MLflow + _log_llm_metrics( + latency=latency, + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + ) + # Yield text progressively for a nicer streaming feel chunk_size = 4 words = text.split(" ") @@ -164,6 +251,16 @@ def single_prompt( text = result["choices"][0]["message"]["content"] usage = result.get("usage", {}) + # Log to MLflow + _log_llm_metrics( + latency=latency, + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + ) + metrics = f""" **Generation Metrics:** - Latency: {latency:.1f}s diff --git a/stt.py b/stt.py index 95da4ef..cc555ea 100644 --- a/stt.py +++ b/stt.py @@ -37,6 +37,67 @@ MLFLOW_TRACKING_URI = os.environ.get( "http://mlflow.mlflow.svc.cluster.local:80" ) +# ─── MLflow experiment tracking ────────────────────────────────────────── +try: + import mlflow + from mlflow.tracking import MlflowClient + + mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) + _mlflow_client = MlflowClient() + + _experiment = _mlflow_client.get_experiment_by_name("gradio-stt-tuning") + if _experiment is None: + _experiment_id = _mlflow_client.create_experiment( + "gradio-stt-tuning", + artifact_location="/mlflow/artifacts/gradio-stt-tuning", + ) + else: + _experiment_id = _experiment.experiment_id + + _mlflow_run = mlflow.start_run( + experiment_id=_experiment_id, + run_name=f"gradio-stt-{os.environ.get('HOSTNAME', 'local')}", + tags={"service": "gradio-stt", "endpoint": STT_URL}, + ) + _mlflow_run_id = _mlflow_run.info.run_id + _mlflow_step = 0 + MLFLOW_ENABLED = True + logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id) +except Exception as exc: + logger.warning("MLflow tracking disabled: %s", exc) + _mlflow_client = None + _mlflow_run_id = None + _mlflow_step = 0 + MLFLOW_ENABLED = False + + +def _log_stt_metrics( + latency: float, audio_duration: float, word_count: int, task: str, +) -> None: + """Log STT inference metrics to MLflow (non-blocking best-effort).""" + global _mlflow_step + if not MLFLOW_ENABLED or _mlflow_client is None: + return + try: + _mlflow_step += 1 + ts = int(time.time() * 1000) + rtf = latency / audio_duration if audio_duration > 0 else 0 + _mlflow_client.log_batch( + _mlflow_run_id, + metrics=[ + mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step), + mlflow.entities.Metric("audio_duration_s", audio_duration, ts, _mlflow_step), + mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step), + mlflow.entities.Metric("word_count", word_count, ts, _mlflow_step), + ], + params=[] if _mlflow_step > 1 else [ + mlflow.entities.Param("task", task), + ], + ) + except Exception: + logger.debug("MLflow log failed", exc_info=True) + + # HTTP client with longer timeout for transcription client = httpx.Client(timeout=180.0) @@ -116,6 +177,14 @@ def transcribe_audio( text = result.get("text", "") detected_language = result.get("language", "unknown") + + # Log to MLflow + _log_stt_metrics( + latency=latency, + audio_duration=audio_duration, + word_count=len(text.split()), + task=task, + ) # Status message status = f"✅ Transcribed {audio_duration:.1f}s of audio in {latency*1000:.0f}ms" diff --git a/tts.py b/tts.py index 7041d7f..742d535 100644 --- a/tts.py +++ b/tts.py @@ -37,6 +37,66 @@ MLFLOW_TRACKING_URI = os.environ.get( "http://mlflow.mlflow.svc.cluster.local:80" ) +# ─── MLflow experiment tracking ────────────────────────────────────────── +try: + import mlflow + from mlflow.tracking import MlflowClient + + mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) + _mlflow_client = MlflowClient() + + _experiment = _mlflow_client.get_experiment_by_name("gradio-tts-tuning") + if _experiment is None: + _experiment_id = _mlflow_client.create_experiment( + "gradio-tts-tuning", + artifact_location="/mlflow/artifacts/gradio-tts-tuning", + ) + else: + _experiment_id = _experiment.experiment_id + + _mlflow_run = mlflow.start_run( + experiment_id=_experiment_id, + run_name=f"gradio-tts-{os.environ.get('HOSTNAME', 'local')}", + tags={"service": "gradio-tts", "endpoint": TTS_URL}, + ) + _mlflow_run_id = _mlflow_run.info.run_id + _mlflow_step = 0 + MLFLOW_ENABLED = True + logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id) +except Exception as exc: + logger.warning("MLflow tracking disabled: %s", exc) + _mlflow_client = None + _mlflow_run_id = None + _mlflow_step = 0 + MLFLOW_ENABLED = False + + +def _log_tts_metrics( + latency: float, audio_duration: float, text_chars: int, language: str, +) -> None: + """Log TTS inference metrics to MLflow (non-blocking best-effort).""" + global _mlflow_step + if not MLFLOW_ENABLED or _mlflow_client is None: + return + try: + _mlflow_step += 1 + ts = int(time.time() * 1000) + rtf = latency / audio_duration if audio_duration > 0 else 0 + cps = text_chars / latency if latency > 0 else 0 + _mlflow_client.log_batch( + _mlflow_run_id, + metrics=[ + mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step), + mlflow.entities.Metric("audio_duration_s", audio_duration, ts, _mlflow_step), + mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step), + mlflow.entities.Metric("chars_per_second", cps, ts, _mlflow_step), + mlflow.entities.Metric("text_chars", text_chars, ts, _mlflow_step), + ], + ) + except Exception: + logger.debug("MLflow log failed", exc_info=True) + + # HTTP client with longer timeout for audio generation client = httpx.Client(timeout=120.0) @@ -93,6 +153,14 @@ def synthesize_speech(text: str, language: str) -> tuple[str, tuple[int, np.ndar # Status message status = f"✅ Generated {duration:.2f}s of audio in {latency*1000:.0f}ms" + + # Log to MLflow + _log_tts_metrics( + latency=latency, + audio_duration=duration, + text_chars=len(text), + language=lang_code, + ) # Metrics metrics = f"""