style: apply ruff format to all files
Some checks failed
CI / Lint (push) Successful in 1m46s
CI / Test (push) Successful in 1m44s
CI / Publish (push) Failing after 19s
CI / Notify (push) Successful in 1s

This commit is contained in:
2026-02-13 11:05:26 -05:00
parent 1c841729a0
commit ca5bef9664
7 changed files with 89 additions and 222 deletions

View File

@@ -63,10 +63,7 @@ def cmd_list_experiments(args):
def cmd_compare(args): def cmd_compare(args):
"""Compare recent runs in an experiment.""" """Compare recent runs in an experiment."""
analyzer = ExperimentAnalyzer( analyzer = ExperimentAnalyzer(args.experiment, tracking_uri=args.tracking_uri)
args.experiment,
tracking_uri=args.tracking_uri
)
if args.run_ids: if args.run_ids:
run_ids = args.run_ids.split(",") run_ids = args.run_ids.split(",")
@@ -82,10 +79,7 @@ def cmd_compare(args):
def cmd_best(args): def cmd_best(args):
"""Find the best run by a metric.""" """Find the best run by a metric."""
analyzer = ExperimentAnalyzer( analyzer = ExperimentAnalyzer(args.experiment, tracking_uri=args.tracking_uri)
args.experiment,
tracking_uri=args.tracking_uri
)
best_run = analyzer.get_best_run( best_run = analyzer.get_best_run(
metric=args.metric, metric=args.metric,
@@ -115,10 +109,7 @@ def cmd_best(args):
def cmd_summary(args): def cmd_summary(args):
"""Get metrics summary for an experiment.""" """Get metrics summary for an experiment."""
analyzer = ExperimentAnalyzer( analyzer = ExperimentAnalyzer(args.experiment, tracking_uri=args.tracking_uri)
args.experiment,
tracking_uri=args.tracking_uri
)
summary = analyzer.get_metrics_summary( summary = analyzer.get_metrics_summary(
hours=args.hours, hours=args.hours,
@@ -201,10 +192,7 @@ def cmd_promote(args):
def cmd_query(args): def cmd_query(args):
"""Query runs with a filter.""" """Query runs with a filter."""
analyzer = ExperimentAnalyzer( analyzer = ExperimentAnalyzer(args.experiment, tracking_uri=args.tracking_uri)
args.experiment,
tracking_uri=args.tracking_uri
)
runs = analyzer.search_runs( runs = analyzer.search_runs(
filter_string=args.filter or "", filter_string=args.filter or "",

View File

@@ -22,41 +22,24 @@ class MLflowConfig:
# Tracking server URIs # Tracking server URIs
tracking_uri: str = field( tracking_uri: str = field(
default_factory=lambda: os.environ.get( default_factory=lambda: os.environ.get("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80")
"MLFLOW_TRACKING_URI",
"http://mlflow.mlflow.svc.cluster.local:80"
)
) )
external_uri: str = field( external_uri: str = field(
default_factory=lambda: os.environ.get( default_factory=lambda: os.environ.get("MLFLOW_EXTERNAL_URI", "https://mlflow.lab.daviestechlabs.io")
"MLFLOW_EXTERNAL_URI",
"https://mlflow.lab.daviestechlabs.io"
)
) )
# Artifact storage (NFS PVC mount) # Artifact storage (NFS PVC mount)
artifact_location: str = field( artifact_location: str = field(
default_factory=lambda: os.environ.get( default_factory=lambda: os.environ.get("MLFLOW_ARTIFACT_LOCATION", "/mlflow/artifacts")
"MLFLOW_ARTIFACT_LOCATION",
"/mlflow/artifacts"
)
) )
# Default experiment settings # Default experiment settings
default_experiment: str = field( default_experiment: str = field(
default_factory=lambda: os.environ.get( default_factory=lambda: os.environ.get("MLFLOW_DEFAULT_EXPERIMENT", "llm-workflows")
"MLFLOW_DEFAULT_EXPERIMENT",
"llm-workflows"
)
) )
# Service identification # Service identification
service_name: str = field( service_name: str = field(default_factory=lambda: os.environ.get("OTEL_SERVICE_NAME", "unknown-service"))
default_factory=lambda: os.environ.get(
"OTEL_SERVICE_NAME",
"unknown-service"
)
)
# Additional tags to add to all runs # Additional tags to add to all runs
default_tags: Dict[str, str] = field(default_factory=dict) default_tags: Dict[str, str] = field(default_factory=dict)
@@ -85,10 +68,7 @@ def get_tracking_uri(external: bool = False) -> str:
return config.external_uri if external else config.tracking_uri return config.external_uri if external else config.tracking_uri
def get_mlflow_client( def get_mlflow_client(tracking_uri: Optional[str] = None, configure_global: bool = True) -> MlflowClient:
tracking_uri: Optional[str] = None,
configure_global: bool = True
) -> MlflowClient:
""" """
Get a configured MLflow client. Get a configured MLflow client.
@@ -110,9 +90,7 @@ def get_mlflow_client(
def ensure_experiment( def ensure_experiment(
experiment_name: str, experiment_name: str, artifact_location: Optional[str] = None, tags: Optional[Dict[str, str]] = None
artifact_location: Optional[str] = None,
tags: Optional[Dict[str, str]] = None
) -> str: ) -> str:
""" """
Ensure an experiment exists, creating it if necessary. Ensure an experiment exists, creating it if necessary.
@@ -134,11 +112,7 @@ def ensure_experiment(
if experiment is None: if experiment is None:
# Create the experiment # Create the experiment
artifact_loc = artifact_location or f"{config.artifact_location}/{experiment_name}" artifact_loc = artifact_location or f"{config.artifact_location}/{experiment_name}"
experiment_id = client.create_experiment( experiment_id = client.create_experiment(name=experiment_name, artifact_location=artifact_loc, tags=tags or {})
name=experiment_name,
artifact_location=artifact_loc,
tags=tags or {}
)
logger.info(f"Created experiment '{experiment_name}' with ID: {experiment_id}") logger.info(f"Created experiment '{experiment_name}' with ID: {experiment_id}")
else: else:
experiment_id = experiment.experiment_id experiment_id = experiment.experiment_id
@@ -148,9 +122,7 @@ def ensure_experiment(
def get_or_create_registered_model( def get_or_create_registered_model(
model_name: str, model_name: str, description: Optional[str] = None, tags: Optional[Dict[str, str]] = None
description: Optional[str] = None,
tags: Optional[Dict[str, str]] = None
) -> str: ) -> str:
""" """
Get or create a registered model in the Model Registry. Get or create a registered model in the Model Registry.
@@ -172,9 +144,7 @@ def get_or_create_registered_model(
except mlflow.exceptions.MlflowException: except mlflow.exceptions.MlflowException:
# Create the model # Create the model
client.create_registered_model( client.create_registered_model(
name=model_name, name=model_name, description=description or f"Model for {model_name}", tags=tags or {}
description=description or f"Model for {model_name}",
tags=tags or {}
) )
logger.info(f"Created registered model: {model_name}") logger.info(f"Created registered model: {model_name}")

View File

@@ -51,6 +51,7 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class RunComparison: class RunComparison:
"""Comparison result for multiple MLflow runs.""" """Comparison result for multiple MLflow runs."""
run_ids: List[str] run_ids: List[str]
experiment_name: str experiment_name: str
@@ -111,6 +112,7 @@ class RunComparison:
@dataclass @dataclass
class PromotionRecommendation: class PromotionRecommendation:
"""Recommendation for model promotion.""" """Recommendation for model promotion."""
model_name: str model_name: str
version: Optional[int] version: Optional[int]
recommended: bool recommended: bool
@@ -262,13 +264,9 @@ class ExperimentAnalyzer:
# Metadata # Metadata
comparison.run_names[run_id] = run.info.run_name or run_id[:8] comparison.run_names[run_id] = run.info.run_name or run_id[:8]
comparison.start_times[run_id] = datetime.fromtimestamp( comparison.start_times[run_id] = datetime.fromtimestamp(run.info.start_time / 1000)
run.info.start_time / 1000
)
if run.info.end_time: if run.info.end_time:
comparison.durations[run_id] = ( comparison.durations[run_id] = (run.info.end_time - run.info.start_time) / 1000
run.info.end_time - run.info.start_time
) / 1000
# Metrics # Metrics
for key, value in run.data.metrics.items(): for key, value in run.data.metrics.items():
@@ -288,10 +286,7 @@ class ExperimentAnalyzer:
continue continue
# Determine if lower is better based on metric name # Determine if lower is better based on metric name
minimize = any( minimize = any(term in metric_name.lower() for term in ["latency", "error", "loss", "time"])
term in metric_name.lower()
for term in ["latency", "error", "loss", "time"]
)
if minimize: if minimize:
best_id = min(values.keys(), key=lambda k: values[k]) best_id = min(values.keys(), key=lambda k: values[k])
@@ -330,10 +325,7 @@ class ExperimentAnalyzer:
) )
# Filter to only runs that have the metric # Filter to only runs that have the metric
runs_with_metric = [ runs_with_metric = [r for r in runs if metric in r.data.metrics]
r for r in runs
if metric in r.data.metrics
]
return runs_with_metric[0] if runs_with_metric else None return runs_with_metric[0] if runs_with_metric else None

View File

@@ -29,6 +29,7 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class InferenceMetrics: class InferenceMetrics:
"""Metrics collected during an inference request.""" """Metrics collected during an inference request."""
request_id: str request_id: str
user_id: Optional[str] = None user_id: Optional[str] = None
session_id: Optional[str] = None session_id: Optional[str] = None
@@ -190,31 +191,22 @@ class InferenceMetricsTracker:
# Initialize MLflow in thread pool to avoid blocking # Initialize MLflow in thread pool to avoid blocking
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor( await loop.run_in_executor(self._executor, self._init_mlflow)
self._executor,
self._init_mlflow
)
if self.enable_batching: if self.enable_batching:
self._flush_task = asyncio.create_task(self._periodic_flush()) self._flush_task = asyncio.create_task(self._periodic_flush())
logger.info( logger.info(f"InferenceMetricsTracker started for {self.service_name} (experiment: {self.experiment_name})")
f"InferenceMetricsTracker started for {self.service_name} "
f"(experiment: {self.experiment_name})"
)
def _init_mlflow(self) -> None: def _init_mlflow(self) -> None:
"""Initialize MLflow client and experiment (runs in thread pool).""" """Initialize MLflow client and experiment (runs in thread pool)."""
self._client = get_mlflow_client( self._client = get_mlflow_client(tracking_uri=self.tracking_uri, configure_global=True)
tracking_uri=self.tracking_uri,
configure_global=True
)
self._experiment_id = ensure_experiment( self._experiment_id = ensure_experiment(
self.experiment_name, self.experiment_name,
tags={ tags={
"service": self.service_name, "service": self.service_name,
"type": "inference-metrics", "type": "inference-metrics",
} },
) )
async def stop(self) -> None: async def stop(self) -> None:
@@ -265,10 +257,7 @@ class InferenceMetricsTracker:
else: else:
# Immediate logging in thread pool # Immediate logging in thread pool
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor( await loop.run_in_executor(self._executor, partial(self._log_single_inference, metrics))
self._executor,
partial(self._log_single_inference, metrics)
)
async def _periodic_flush(self) -> None: async def _periodic_flush(self) -> None:
"""Periodically flush batched metrics.""" """Periodically flush batched metrics."""
@@ -287,10 +276,7 @@ class InferenceMetricsTracker:
# Log in thread pool # Log in thread pool
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor( await loop.run_in_executor(self._executor, partial(self._log_batch, batch))
self._executor,
partial(self._log_batch, batch)
)
def _log_single_inference(self, metrics: InferenceMetrics) -> None: def _log_single_inference(self, metrics: InferenceMetrics) -> None:
"""Log a single inference request to MLflow (runs in thread pool).""" """Log a single inference request to MLflow (runs in thread pool)."""
@@ -302,7 +288,7 @@ class InferenceMetricsTracker:
"service": self.service_name, "service": self.service_name,
"request_id": metrics.request_id, "request_id": metrics.request_id,
"type": "single-inference", "type": "single-inference",
} },
): ):
mlflow.log_params(metrics.as_params_dict()) mlflow.log_params(metrics.as_params_dict())
mlflow.log_metrics(metrics.as_metrics_dict()) mlflow.log_metrics(metrics.as_metrics_dict())
@@ -336,7 +322,7 @@ class InferenceMetricsTracker:
"service": self.service_name, "service": self.service_name,
"type": "batch-inference", "type": "batch-inference",
"batch_size": str(len(batch)), "batch_size": str(len(batch)),
} },
): ):
# Log aggregate metrics # Log aggregate metrics
mlflow.log_metrics(aggregates) mlflow.log_metrics(aggregates)
@@ -352,12 +338,14 @@ class InferenceMetricsTracker:
premium_count = sum(1 for m in batch if m.is_premium) premium_count = sum(1 for m in batch if m.is_premium)
error_count = sum(1 for m in batch if m.has_error) error_count = sum(1 for m in batch if m.has_error)
mlflow.log_metrics({ mlflow.log_metrics(
"rag_enabled_pct": rag_enabled_count / len(batch) * 100, {
"streaming_pct": streaming_count / len(batch) * 100, "rag_enabled_pct": rag_enabled_count / len(batch) * 100,
"premium_pct": premium_count / len(batch) * 100, "streaming_pct": streaming_count / len(batch) * 100,
"error_rate": error_count / len(batch) * 100, "premium_pct": premium_count / len(batch) * 100,
}) "error_rate": error_count / len(batch) * 100,
}
)
# Log model distribution # Log model distribution
model_counts: Dict[str, int] = defaultdict(int) model_counts: Dict[str, int] = defaultdict(int)
@@ -366,20 +354,14 @@ class InferenceMetricsTracker:
model_counts[m.model_name] += 1 model_counts[m.model_name] += 1
if model_counts: if model_counts:
mlflow.log_dict( mlflow.log_dict({"models": dict(model_counts)}, "model_distribution.json")
{"models": dict(model_counts)},
"model_distribution.json"
)
logger.info(f"Logged batch of {len(batch)} inference metrics") logger.info(f"Logged batch of {len(batch)} inference metrics")
except Exception as e: except Exception as e:
logger.error(f"Failed to log batch metrics: {e}") logger.error(f"Failed to log batch metrics: {e}")
def _calculate_aggregates( def _calculate_aggregates(self, batch: List[InferenceMetrics]) -> Dict[str, float]:
self,
batch: List[InferenceMetrics]
) -> Dict[str, float]:
"""Calculate aggregate statistics from a batch of metrics.""" """Calculate aggregate statistics from a batch of metrics."""
import statistics import statistics

View File

@@ -47,17 +47,14 @@ MLFLOW_PACKAGES = [
] ]
@dsl.component( @dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def create_mlflow_run( def create_mlflow_run(
experiment_name: str, experiment_name: str,
run_name: str, run_name: str,
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80", mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
tags: Dict[str, str] = None, tags: Dict[str, str] = None,
params: Dict[str, str] = None, params: Dict[str, str] = None,
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]): ) -> NamedTuple("RunInfo", [("run_id", str), ("experiment_id", str), ("artifact_uri", str)]):
""" """
Create a new MLflow run for the pipeline. Create a new MLflow run for the pipeline.
@@ -90,8 +87,7 @@ def create_mlflow_run(
experiment = client.get_experiment_by_name(experiment_name) experiment = client.get_experiment_by_name(experiment_name)
if experiment is None: if experiment is None:
experiment_id = client.create_experiment( experiment_id = client.create_experiment(
name=experiment_name, name=experiment_name, artifact_location=f"/mlflow/artifacts/{experiment_name}"
artifact_location=f"/mlflow/artifacts/{experiment_name}"
) )
else: else:
experiment_id = experiment.experiment_id experiment_id = experiment.experiment_id
@@ -122,14 +118,11 @@ def create_mlflow_run(
# End run (KFP components are isolated, we'll resume in other components) # End run (KFP components are isolated, we'll resume in other components)
mlflow.end_run() mlflow.end_run()
RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri']) RunInfo = namedtuple("RunInfo", ["run_id", "experiment_id", "artifact_uri"])
return RunInfo(run_id, experiment_id, artifact_uri) return RunInfo(run_id, experiment_id, artifact_uri)
@dsl.component( @dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def log_params_component( def log_params_component(
run_id: str, run_id: str,
params: Dict[str, str], params: Dict[str, str],
@@ -158,10 +151,7 @@ def log_params_component(
return run_id return run_id
@dsl.component( @dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def log_metrics_component( def log_metrics_component(
run_id: str, run_id: str,
metrics: Dict[str, float], metrics: Dict[str, float],
@@ -192,10 +182,7 @@ def log_metrics_component(
return run_id return run_id
@dsl.component( @dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def log_artifact_component( def log_artifact_component(
run_id: str, run_id: str,
artifact_path: str, artifact_path: str,
@@ -225,10 +212,7 @@ def log_artifact_component(
return run_id return run_id
@dsl.component( @dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def log_dict_artifact( def log_dict_artifact(
run_id: str, run_id: str,
data: Dict[str, Any], data: Dict[str, Any],
@@ -258,23 +242,20 @@ def log_dict_artifact(
client = MlflowClient() client = MlflowClient()
# Ensure .json extension # Ensure .json extension
if not filename.endswith('.json'): if not filename.endswith(".json"):
filename += '.json' filename += ".json"
# Write to temp file and log # Write to temp file and log
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
filepath = Path(tmpdir) / filename filepath = Path(tmpdir) / filename
with open(filepath, 'w') as f: with open(filepath, "w") as f:
json.dump(data, f, indent=2) json.dump(data, f, indent=2)
client.log_artifact(run_id, str(filepath)) client.log_artifact(run_id, str(filepath))
return run_id return run_id
@dsl.component( @dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def end_mlflow_run( def end_mlflow_run(
run_id: str, run_id: str,
status: str = "FINISHED", status: str = "FINISHED",
@@ -310,10 +291,7 @@ def end_mlflow_run(
return run_id return run_id
@dsl.component( @dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES + ["httpx"])
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES + ["httpx"]
)
def log_training_metrics( def log_training_metrics(
run_id: str, run_id: str,
model_type: str, model_type: str,
@@ -370,7 +348,7 @@ def log_training_metrics(
# Log full config as artifact # Log full config as artifact
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
config_path = Path(tmpdir) / "training_config.json" config_path = Path(tmpdir) / "training_config.json"
with open(config_path, 'w') as f: with open(config_path, "w") as f:
json.dump(training_config, f, indent=2) json.dump(training_config, f, indent=2)
client.log_artifact(run_id, str(config_path)) client.log_artifact(run_id, str(config_path))
@@ -382,10 +360,7 @@ def log_training_metrics(
return run_id return run_id
@dsl.component( @dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def log_document_ingestion_metrics( def log_document_ingestion_metrics(
run_id: str, run_id: str,
source_url: str, source_url: str,
@@ -452,10 +427,7 @@ def log_document_ingestion_metrics(
return run_id return run_id
@dsl.component( @dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def log_evaluation_results( def log_evaluation_results(
run_id: str, run_id: str,
model_name: str, model_name: str,
@@ -502,7 +474,7 @@ def log_evaluation_results(
if sample_results: if sample_results:
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
results_path = Path(tmpdir) / "evaluation_results.json" results_path = Path(tmpdir) / "evaluation_results.json"
with open(results_path, 'w') as f: with open(results_path, "w") as f:
json.dump(sample_results, f, indent=2) json.dump(sample_results, f, indent=2)
client.log_artifact(run_id, str(results_path)) client.log_artifact(run_id, str(results_path))

View File

@@ -192,7 +192,7 @@ def register_model_for_kserve(
tags={ tags={
"model.type": model_type, "model.type": model_type,
"deployment.target": "kserve", "deployment.target": "kserve",
} },
) )
logger.info(f"Created registered model: {model_name}") logger.info(f"Created registered model: {model_name}")
@@ -206,13 +206,10 @@ def register_model_for_kserve(
**(tags or {}), **(tags or {}),
"model.type": model_type, "model.type": model_type,
**kserve_config.as_dict(), **kserve_config.as_dict(),
} },
) )
logger.info( logger.info(f"Registered model version {model_version.version} for {model_name} (type: {model_type})")
f"Registered model version {model_version.version} "
f"for {model_name} (type: {model_type})"
)
return model_version return model_version
@@ -412,25 +409,24 @@ def generate_kserve_manifest(
if config.container_image: if config.container_image:
# Custom container # Custom container
predictor["containers"] = [{ predictor["containers"] = [
"name": "predictor", {
"image": config.container_image, "name": "predictor",
"ports": [{"containerPort": config.container_port, "protocol": "TCP"}], "image": config.container_image,
"resources": { "ports": [{"containerPort": config.container_port, "protocol": "TCP"}],
"requests": { "resources": {
"cpu": config.cpu_request, "requests": {
"memory": config.memory_request, "cpu": config.cpu_request,
"memory": config.memory_request,
},
"limits": {
"cpu": config.cpu_limit,
"memory": config.memory_limit,
},
}, },
"limits": { "env": [{"name": k, "value": v} for k, v in config.env_vars.items()],
"cpu": config.cpu_limit, }
"memory": config.memory_limit, ]
},
},
"env": [
{"name": k, "value": v}
for k, v in config.env_vars.items()
],
}]
# Add GPU if needed # Add GPU if needed
if config.gpu_count > 0: if config.gpu_count > 0:
@@ -494,7 +490,7 @@ def generate_kserve_yaml(
yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False) yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False)
if output_path: if output_path:
with open(output_path, 'w') as f: with open(output_path, "w") as f:
f.write(yaml_str) f.write(yaml_str)
logger.info(f"Wrote KServe manifest to {output_path}") logger.info(f"Wrote KServe manifest to {output_path}")

View File

@@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class PipelineMetadata: class PipelineMetadata:
"""Metadata about the Kubeflow Pipeline run.""" """Metadata about the Kubeflow Pipeline run."""
pipeline_name: str pipeline_name: str
run_id: str run_id: str
run_name: Optional[str] = None run_name: Optional[str] = None
@@ -30,12 +31,8 @@ class PipelineMetadata:
namespace: str = "ai-ml" namespace: str = "ai-ml"
# KFP-specific metadata (populated from environment if available) # KFP-specific metadata (populated from environment if available)
kfp_run_id: Optional[str] = field( kfp_run_id: Optional[str] = field(default_factory=lambda: os.environ.get("KFP_RUN_ID"))
default_factory=lambda: os.environ.get("KFP_RUN_ID") kfp_pod_name: Optional[str] = field(default_factory=lambda: os.environ.get("KFP_POD_NAME"))
)
kfp_pod_name: Optional[str] = field(
default_factory=lambda: os.environ.get("KFP_POD_NAME")
)
def as_tags(self) -> Dict[str, str]: def as_tags(self) -> Dict[str, str]:
"""Convert metadata to MLflow tags.""" """Convert metadata to MLflow tags."""
@@ -142,10 +139,7 @@ class MLflowTracker:
Yields: Yields:
The MLflow run object The MLflow run object
""" """
self.client = get_mlflow_client( self.client = get_mlflow_client(tracking_uri=self.tracking_uri, configure_global=True)
tracking_uri=self.tracking_uri,
configure_global=True
)
# Ensure experiment exists # Ensure experiment exists
experiment_id = ensure_experiment(self.experiment_name) experiment_id = ensure_experiment(self.experiment_name)
@@ -163,8 +157,7 @@ class MLflowTracker:
self.run_id = self.run.info.run_id self.run_id = self.run.info.run_id
logger.info( logger.info(
f"Started MLflow run '{self.run_name}' " f"Started MLflow run '{self.run_name}' (ID: {self.run_id}) in experiment '{self.experiment_name}'"
f"(ID: {self.run_id}) in experiment '{self.experiment_name}'"
) )
yield self.run yield self.run
@@ -214,11 +207,7 @@ class MLflowTracker:
"""Log a single parameter.""" """Log a single parameter."""
self.log_params({key: value}) self.log_params({key: value})
def log_metrics( def log_metrics(self, metrics: Dict[str, Union[float, int]], step: Optional[int] = None) -> None:
self,
metrics: Dict[str, Union[float, int]],
step: Optional[int] = None
) -> None:
""" """
Log metrics to the current run. Log metrics to the current run.
@@ -233,20 +222,11 @@ class MLflowTracker:
mlflow.log_metrics(metrics, step=step) mlflow.log_metrics(metrics, step=step)
logger.debug(f"Logged {len(metrics)} metrics") logger.debug(f"Logged {len(metrics)} metrics")
def log_metric( def log_metric(self, key: str, value: Union[float, int], step: Optional[int] = None) -> None:
self,
key: str,
value: Union[float, int],
step: Optional[int] = None
) -> None:
"""Log a single metric.""" """Log a single metric."""
self.log_metrics({key: value}, step=step) self.log_metrics({key: value}, step=step)
def log_artifact( def log_artifact(self, local_path: str, artifact_path: Optional[str] = None) -> None:
self,
local_path: str,
artifact_path: Optional[str] = None
) -> None:
""" """
Log an artifact file to the current run. Log an artifact file to the current run.
@@ -261,11 +241,7 @@ class MLflowTracker:
mlflow.log_artifact(local_path, artifact_path) mlflow.log_artifact(local_path, artifact_path)
logger.info(f"Logged artifact: {local_path}") logger.info(f"Logged artifact: {local_path}")
def log_artifacts( def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None) -> None:
self,
local_dir: str,
artifact_path: Optional[str] = None
) -> None:
""" """
Log all files in a directory as artifacts. Log all files in a directory as artifacts.
@@ -280,12 +256,7 @@ class MLflowTracker:
mlflow.log_artifacts(local_dir, artifact_path) mlflow.log_artifacts(local_dir, artifact_path)
logger.info(f"Logged artifacts from: {local_dir}") logger.info(f"Logged artifacts from: {local_dir}")
def log_dict( def log_dict(self, data: Dict[str, Any], filename: str, artifact_path: Optional[str] = None) -> None:
self,
data: Dict[str, Any],
filename: str,
artifact_path: Optional[str] = None
) -> None:
""" """
Log a dictionary as a JSON artifact. Log a dictionary as a JSON artifact.
@@ -311,7 +282,7 @@ class MLflowTracker:
model_name: str, model_name: str,
model_path: Optional[str] = None, model_path: Optional[str] = None,
framework: str = "pytorch", framework: str = "pytorch",
extra_info: Optional[Dict[str, Any]] = None extra_info: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
""" """
Log model information as parameters and tags. Log model information as parameters and tags.
@@ -341,11 +312,7 @@ class MLflowTracker:
mlflow.set_tag("model.name", model_name) mlflow.set_tag("model.name", model_name)
def log_dataset_info( def log_dataset_info(
self, self, name: str, source: str, size: Optional[int] = None, extra_info: Optional[Dict[str, Any]] = None
name: str,
source: str,
size: Optional[int] = None,
extra_info: Optional[Dict[str, Any]] = None
) -> None: ) -> None:
""" """
Log dataset information. Log dataset information.