"""Log avatar generation results to MLflow via REST API. Uses the same lightweight REST-only approach as ray-serve's mlflow_logger.py — no heavyweight mlflow SDK dependency. Usage: from avatar_pipeline.log_mlflow import log_generation log_generation( avatar_name="Silver-Mage", params={"trellis_seed": 42, "trellis_simplify": 0.95}, metrics={"vertex_count": 12345, "face_count": 8000, "duration_s": 45.2}, artifacts={"vrm": Path("exports/Silver-Mage.vrm")}, ) """ from __future__ import annotations import json import logging import os import time import urllib.error import urllib.request from pathlib import Path logger = logging.getLogger(__name__) DEFAULT_TRACKING_URI = "http://mlflow.mlflow.svc.cluster.local:80" EXPERIMENT_NAME = "3d-avatar-generation" def _base_url() -> str: return os.environ.get("MLFLOW_TRACKING_URI", DEFAULT_TRACKING_URI).rstrip("/") def _post(path: str, body: dict) -> dict: url = f"{_base_url()}/api/2.0/mlflow/{path}" data = json.dumps(body).encode() req = urllib.request.Request( url, data=data, headers={"Content-Type": "application/json"}, method="POST" ) with urllib.request.urlopen(req, timeout=15) as resp: return json.loads(resp.read().decode()) def _get_or_create_experiment(name: str) -> str: try: resp = _post("experiments/get-by-name", {"experiment_name": name}) return resp["experiment"]["experiment_id"] except urllib.error.HTTPError: resp = _post("experiments/create", {"name": name}) return resp["experiment_id"] def _create_run(experiment_id: str, run_name: str, tags: dict[str, str]) -> str: tag_list = [ {"key": k, "value": v} for k, v in { "mlflow.runName": run_name, "mlflow.source.type": "LOCAL", "mlflow.source.name": "avatar-pipeline", "hostname": os.environ.get("HOSTNAME", "desktop"), **tags, }.items() ] resp = _post( "runs/create", { "experiment_id": experiment_id, "run_name": run_name, "start_time": int(time.time() * 1000), "tags": tag_list, }, ) return resp["run"]["info"]["run_id"] def _log_params(run_id: str, params: dict[str, str | int | float]) -> None: param_list = [{"key": k, "value": str(v)[:500]} for k, v in params.items()] _post("runs/log-batch", {"run_id": run_id, "params": param_list}) def _log_metrics(run_id: str, metrics: dict[str, float]) -> None: ts = int(time.time() * 1000) metric_list = [ {"key": k, "value": float(v), "timestamp": ts, "step": 0} for k, v in metrics.items() ] _post("runs/log-batch", {"run_id": run_id, "metrics": metric_list}) def _log_artifact(run_id: str, key: str, path: Path) -> None: """Log an artifact path as a tag (actual artifact upload requires artifact store). For local files, we record the path. For S3-promoted files, the caller should include the S3 URI in params instead. """ _post( "runs/log-batch", { "run_id": run_id, "tags": [{"key": f"artifact.{key}", "value": str(path)}], }, ) def _end_run(run_id: str) -> None: _post( "runs/update", { "run_id": run_id, "status": "FINISHED", "end_time": int(time.time() * 1000), }, ) def log_generation( avatar_name: str, params: dict[str, str | int | float] | None = None, metrics: dict[str, float] | None = None, artifacts: dict[str, Path] | None = None, tags: dict[str, str] | None = None, experiment_name: str = EXPERIMENT_NAME, ) -> str | None: """Log a complete avatar generation run to MLflow. Returns the MLflow run_id on success, None on failure. """ try: exp_id = _get_or_create_experiment(experiment_name) run_id = _create_run(exp_id, run_name=avatar_name, tags=tags or {}) if params: _log_params(run_id, params) if metrics: _log_metrics(run_id, metrics) if artifacts: for key, path in artifacts.items(): _log_artifact(run_id, key, path) _end_run(run_id) logger.info("Logged to MLflow: run_id=%s, experiment=%s", run_id, experiment_name) return run_id except Exception: logger.warning("MLflow logging failed — generation still succeeded", exc_info=True) return None