#!/usr/bin/env python3
"""MLflow-Integrated Document Ingestion Pipeline.

Enhanced version of document_ingestion_pipeline.py with full MLflow
experiment tracking for metrics, parameters, and artifacts.

Usage:
    pip install kfp==2.12.1 "mlflow>=2.10.0"
    python document_ingestion_mlflow_pipeline.py
    # Upload to Kubeflow Pipelines UI
"""

from typing import NamedTuple

from kfp import compiler
from kfp import dsl

# Base image and pip packages shared by the tracking-enabled components.
MLFLOW_IMAGE = "python:3.13-slim"
MLFLOW_PACKAGES = [
    "mlflow>=2.10.0",
    "psycopg2-binary>=2.9.0",
    "httpx",
    "beautifulsoup4",
    "pypdf2",
    "docx2txt",
    "tiktoken",
    "pymilvus",
]


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=["mlflow>=2.10.0", "psycopg2-binary"],
)
def start_mlflow_run(
    experiment_name: str,
    run_name: str,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
    pipeline_params: dict = None,
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str)]):
    """Create an MLflow run and log initial pipeline parameters.

    Args:
        experiment_name: MLflow experiment to create or reuse.
        run_name: Display name for the run; an ``ingestion-<timestamp>``
            name is generated when empty (KFP pipeline channels cannot be
            defaulted in the pipeline body, so the fallback lives here).
        mlflow_tracking_uri: MLflow tracking server URI.
        pipeline_params: Optional mapping of pipeline parameters to log.

    Returns:
        RunInfo(run_id, experiment_id) for downstream components.
    """
    import os
    import time
    from collections import namedtuple

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Get or create the experiment.
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        experiment_id = client.create_experiment(
            name=experiment_name,
            artifact_location=f"/mlflow/artifacts/{experiment_name}",
        )
    else:
        experiment_id = experiment.experiment_id

    # Pipeline-channel fallbacks don't work at compile time, so default here.
    if not run_name:
        run_name = f"ingestion-{int(time.time())}"

    tags = {
        "pipeline.type": "document-ingestion",
        "pipeline.framework": "kubeflow",
        "kfp.run_id": os.environ.get("KFP_RUN_ID", "unknown"),
    }

    # Use client.create_run (not mlflow.start_run/end_run) so the run stays
    # in RUNNING state while downstream components log to it; it is
    # terminated explicitly by finalize_mlflow_run.
    run = client.create_run(
        experiment_id=experiment_id,
        run_name=run_name,
        tags=tags,
    )
    run_id = run.info.run_id

    if pipeline_params:
        for key, value in pipeline_params.items():
            # MLflow caps parameter value length; truncate defensively.
            client.log_param(run_id, key, str(value)[:500])

    RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id'])
    return RunInfo(run_id, str(experiment_id))


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES,
)
def extract_text_with_tracking(
    source_url: str,
    run_id: str,
    source_type: str = "auto",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> NamedTuple('ExtractionResult', [('text', str), ('char_count', int), ('source_type', str)]):
    """Download a document and extract its text, logging metrics to MLflow.

    Args:
        source_url: HTTP(S) URL of the document to ingest.
        run_id: MLflow run to log parameters/metrics into.
        source_type: One of "pdf", "docx", "html", "text", or "auto" to
            detect from the URL suffix (case-insensitive).
        mlflow_tracking_uri: MLflow tracking server URI.

    Returns:
        ExtractionResult(text, char_count, source_type).
    """
    import os
    import tempfile
    import time
    from collections import namedtuple

    import httpx
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()
    start_time = time.time()

    # Download content; fail fast on HTTP errors instead of parsing an
    # error page as document content.
    with httpx.Client(timeout=120.0) as http_client:
        response = http_client.get(source_url)
        response.raise_for_status()
        content = response.content

    # Detect type from the URL suffix (case-insensitive) if requested.
    if source_type == "auto":
        url = source_url.lower()
        if url.endswith(".pdf"):
            source_type = "pdf"
        elif url.endswith(".docx"):
            source_type = "docx"
        elif url.endswith((".html", ".htm")):
            source_type = "html"
        else:
            source_type = "text"

    # Extract text based on type.
    if source_type == "pdf":
        import io

        from PyPDF2 import PdfReader

        reader = PdfReader(io.BytesIO(content))
        # extract_text() can return None for image-only pages.
        text = "\n\n".join(page.extract_text() or "" for page in reader.pages)
    elif source_type == "docx":
        import docx2txt

        # docx2txt needs a real file path; clean the temp file up afterwards.
        with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as f:
            f.write(content)
        try:
            text = docx2txt.process(f.name)
        finally:
            os.unlink(f.name)
    elif source_type == "html":
        from bs4 import BeautifulSoup

        soup = BeautifulSoup(content, "html.parser")
        text = soup.get_text(separator="\n")
    else:
        text = content.decode("utf-8", errors="ignore")

    extraction_time = time.time() - start_time

    # Log to MLflow.
    client.log_param(run_id, "source_url", source_url[:500])
    client.log_param(run_id, "source_type", source_type)
    client.log_metric(run_id, "extraction_time_seconds", extraction_time)
    client.log_metric(run_id, "source_size_bytes", len(content))
    client.log_metric(run_id, "extracted_char_count", len(text))

    ExtractionResult = namedtuple(
        'ExtractionResult', ['text', 'char_count', 'source_type']
    )
    return ExtractionResult(text, len(text), source_type)


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES,
)
def chunk_text_with_tracking(
    text: str,
    run_id: str,
    chunk_size: int = 500,
    overlap: int = 50,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> NamedTuple('ChunkResult', [('chunks', list), ('chunk_count', int)]):
    """Split text into token-based chunks, logging stats to MLflow.

    Args:
        text: Raw document text.
        run_id: MLflow run to log parameters/metrics into.
        chunk_size: Maximum tokens per chunk (must be > 0).
        overlap: Tokens shared between consecutive chunks
            (must satisfy 0 <= overlap < chunk_size).
        mlflow_tracking_uri: MLflow tracking server URI.

    Returns:
        ChunkResult(chunks, chunk_count) where each chunk is a dict with
        "text", "start_token", "end_token", and "token_count".

    Raises:
        ValueError: If chunk_size/overlap would make the loop step <= 0.
    """
    import time
    from collections import namedtuple

    import mlflow
    import tiktoken
    from mlflow.tracking import MlflowClient

    # Guard against an infinite loop: the window advances by
    # chunk_size - overlap each iteration, which must be positive.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got {overlap}"
        )

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()
    start_time = time.time()

    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(text)

    chunks = []
    start = 0
    while start < len(tokens):
        end = min(start + chunk_size, len(tokens))
        window = tokens[start:end]
        chunks.append({
            "text": enc.decode(window),
            "start_token": start,
            "end_token": end,
            "token_count": len(window),
        })
        start += chunk_size - overlap

    chunking_time = time.time() - start_time

    # Log to MLflow.
    client.log_param(run_id, "chunk_size", chunk_size)
    client.log_param(run_id, "chunk_overlap", overlap)
    client.log_metric(run_id, "chunking_time_seconds", chunking_time)
    client.log_metric(run_id, "total_tokens", len(tokens))
    client.log_metric(run_id, "chunks_created", len(chunks))
    client.log_metric(
        run_id,
        "avg_tokens_per_chunk",
        len(tokens) / len(chunks) if chunks else 0,
    )

    ChunkResult = namedtuple('ChunkResult', ['chunks', 'chunk_count'])
    return ChunkResult(chunks, len(chunks))


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES,
)
def generate_embeddings_with_tracking(
    chunks: list,
    run_id: str,
    embeddings_url: str = "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings",
    embeddings_model: str = "bge-small-en-v1.5",
    batch_size: int = 32,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> NamedTuple('EmbeddingResult', [('embedded_chunks', list), ('embedding_dim', int)]):
    """Generate embeddings for chunks via the inference service.

    Args:
        chunks: Chunk dicts from chunk_text_with_tracking (each has "text").
        run_id: MLflow run to log parameters/metrics into.
        embeddings_url: Full URL of the embeddings endpoint (the default
            already includes the /embeddings path).
        embeddings_model: Model name passed to the endpoint.
        batch_size: Texts per request.
        mlflow_tracking_uri: MLflow tracking server URI.

    Returns:
        EmbeddingResult(embedded_chunks, embedding_dim); each chunk dict
        gains an "embedding" key.
    """
    import time
    from collections import namedtuple

    import httpx
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()
    start_time = time.time()

    embedded_chunks = []
    texts = [c["text"] for c in chunks]
    embedding_dim = 0

    with httpx.Client(timeout=300.0) as http_client:
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            # BUG FIX: embeddings_url already points at the /embeddings
            # endpoint; the previous code appended "/embeddings" again,
            # producing ".../embeddings/embeddings".
            response = http_client.post(
                embeddings_url,
                json={"input": batch, "model": embeddings_model},
            )
            response.raise_for_status()
            result = response.json()
            for j, embedding_data in enumerate(result["data"]):
                chunk = chunks[i + j].copy()
                chunk["embedding"] = embedding_data["embedding"]
                embedded_chunks.append(chunk)
                if embedding_dim == 0:
                    embedding_dim = len(embedding_data["embedding"])

    embedding_time = time.time() - start_time

    # Log to MLflow.
    client.log_param(run_id, "embeddings_model", embeddings_model)
    client.log_param(run_id, "embedding_batch_size", batch_size)
    client.log_metric(run_id, "embedding_time_seconds", embedding_time)
    client.log_metric(run_id, "embedding_dimension", embedding_dim)
    client.log_metric(
        run_id,
        "embeddings_per_second",
        len(chunks) / embedding_time if embedding_time > 0 else 0,
    )

    EmbeddingResult = namedtuple(
        'EmbeddingResult', ['embedded_chunks', 'embedding_dim']
    )
    return EmbeddingResult(embedded_chunks, embedding_dim)


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES,
)
def upsert_to_milvus_with_tracking(
    chunks: list,
    run_id: str,
    collection_name: str,
    source_name: str,
    milvus_host: str = "milvus.ai-ml.svc.cluster.local",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
    source_url: str = "",
) -> NamedTuple('UpsertResult', [('inserted_count', int), ('collection_name', str)]):
    """Insert embedded chunks into a Milvus collection, logging to MLflow.

    Args:
        chunks: Chunk dicts with "text" and "embedding" keys.
        run_id: MLflow run to log parameters/metrics into.
        collection_name: Milvus collection (created on first use).
        source_name: Source label stored with each row; falls back to
            source_url when empty (pipeline channels can't be defaulted
            in the pipeline body).
        milvus_host: Milvus service hostname.
        mlflow_tracking_uri: MLflow tracking server URI.
        source_url: Optional fallback label when source_name is empty.

    Returns:
        UpsertResult(inserted_count, collection_name).
    """
    import time
    from collections import namedtuple

    import mlflow
    from mlflow.tracking import MlflowClient
    from pymilvus import (
        Collection,
        CollectionSchema,
        DataType,
        FieldSchema,
        connections,
        utility,
    )

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()
    start_time = time.time()

    # Resolve the source label at runtime (see docstring).
    source_name = source_name or source_url or "unknown"

    # Nothing to insert: log zeros and return without touching Milvus
    # (chunks[0]["embedding"] below would raise on an empty list).
    if not chunks:
        client.log_param(run_id, "milvus_collection", collection_name)
        client.log_param(run_id, "source_name", source_name[:500])
        client.log_metric(run_id, "documents_inserted", 0)
        UpsertResult = namedtuple(
            'UpsertResult', ['inserted_count', 'collection_name']
        )
        return UpsertResult(0, collection_name)

    connections.connect(host=milvus_host, port=19530)

    # Create collection if needed; embedding dim comes from the data.
    if not utility.has_collection(collection_name):
        fields = [
            FieldSchema(
                name="id", dtype=DataType.INT64, is_primary=True, auto_id=True
            ),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=8192),
            FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=512),
            FieldSchema(
                name="embedding",
                dtype=DataType.FLOAT_VECTOR,
                dim=len(chunks[0]["embedding"]),
            ),
        ]
        schema = CollectionSchema(fields, description="Document embeddings")
        Collection(name=collection_name, schema=schema)

    collection = Collection(collection_name)

    # Prepare and insert data (column-oriented, matching the schema order).
    data = [
        [chunk["text"] for chunk in chunks],
        [source_name for _ in chunks],
        [chunk["embedding"] for chunk in chunks],
    ]
    collection.insert(data)
    collection.flush()

    # Create index if needed.
    if not collection.has_index():
        collection.create_index(
            field_name="embedding",
            index_params={
                "metric_type": "COSINE",
                "index_type": "HNSW",
                "params": {"M": 16, "efConstruction": 256},
            },
        )

    upsert_time = time.time() - start_time
    inserted_count = len(chunks)

    # Log to MLflow.
    client.log_param(run_id, "milvus_collection", collection_name)
    client.log_param(run_id, "source_name", source_name[:500])
    client.log_metric(run_id, "upsert_time_seconds", upsert_time)
    client.log_metric(run_id, "documents_inserted", inserted_count)
    client.log_metric(
        run_id,
        "inserts_per_second",
        inserted_count / upsert_time if upsert_time > 0 else 0,
    )

    connections.disconnect("default")

    UpsertResult = namedtuple('UpsertResult', ['inserted_count', 'collection_name'])
    return UpsertResult(inserted_count, collection_name)


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=["mlflow>=2.10.0", "psycopg2-binary"],
)
def finalize_mlflow_run(
    run_id: str,
    total_chunks: int,
    status: str = "FINISHED",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """Terminate the MLflow run with summary metrics.

    Args:
        run_id: MLflow run to finalize.
        total_chunks: Final chunk count logged as a summary metric.
        status: Terminal status name ("FINISHED", "FAILED", or "KILLED");
            unknown values fall back to "FINISHED".
        mlflow_tracking_uri: MLflow tracking server URI.

    Returns:
        The run_id, for downstream chaining.
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # BUG FIX: MlflowClient.set_terminated expects the status as a STRING
    # (e.g. "FINISHED"), not a RunStatus enum value as previously passed.
    final_status = status.upper()
    if final_status not in ("FINISHED", "FAILED", "KILLED"):
        final_status = "FINISHED"

    # Log summary and terminate.
    client.set_tag(run_id, "pipeline.status", status)
    client.log_metric(run_id, "final_chunk_count", total_chunks)
    client.set_terminated(run_id, status=final_status)

    return run_id


@dsl.pipeline(
    name="mlflow-document-ingestion",
    description="Document ingestion with MLflow experiment tracking",
)
def document_ingestion_mlflow_pipeline(
    source_url: str,
    collection_name: str = "knowledge_base",
    source_name: str = "",
    chunk_size: int = 500,
    chunk_overlap: int = 50,
    experiment_name: str = "document-ingestion",
    run_name: str = "",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
):
    """Document Ingestion Pipeline with MLflow Tracking.

    Ingests documents into Milvus while tracking all metrics and
    parameters in MLflow for experiment comparison.

    NOTE: pipeline parameters are compile-time placeholders, so Python
    expressions such as ``run_name or f"..."`` cannot see their runtime
    values (a channel is always truthy). Empty-value fallbacks therefore
    live inside the components, not here.
    """
    # Start MLflow run (generates a default run name when run_name is "").
    mlflow_run = start_mlflow_run(
        experiment_name=experiment_name,
        run_name=run_name,
        mlflow_tracking_uri=mlflow_tracking_uri,
        pipeline_params={
            "source_url": source_url,
            "collection_name": collection_name,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
        },
    )

    # Extract text.
    extraction = extract_text_with_tracking(
        source_url=source_url,
        run_id=mlflow_run.outputs["run_id"],
        mlflow_tracking_uri=mlflow_tracking_uri,
    )
    extraction.set_caching_options(enable_caching=True)

    # Chunk text.
    chunking = chunk_text_with_tracking(
        text=extraction.outputs["text"],
        run_id=mlflow_run.outputs["run_id"],
        chunk_size=chunk_size,
        overlap=chunk_overlap,
        mlflow_tracking_uri=mlflow_tracking_uri,
    )
    chunking.set_caching_options(enable_caching=True)

    # Generate embeddings.
    embedding = generate_embeddings_with_tracking(
        chunks=chunking.outputs["chunks"],
        run_id=mlflow_run.outputs["run_id"],
        mlflow_tracking_uri=mlflow_tracking_uri,
    )
    embedding.set_caching_options(enable_caching=True)

    # Upsert to Milvus (source_name falls back to source_url at runtime).
    upsert = upsert_to_milvus_with_tracking(
        chunks=embedding.outputs["embedded_chunks"],
        run_id=mlflow_run.outputs["run_id"],
        collection_name=collection_name,
        source_name=source_name,
        source_url=source_url,
        mlflow_tracking_uri=mlflow_tracking_uri,
    )

    # Finalize run.
    finalize_mlflow_run(
        run_id=mlflow_run.outputs["run_id"],
        total_chunks=upsert.outputs["inserted_count"],
        mlflow_tracking_uri=mlflow_tracking_uri,
    )


if __name__ == "__main__":
    compiler.Compiler().compile(
        document_ingestion_mlflow_pipeline,
        "document_ingestion_mlflow_pipeline.yaml",
    )
    print("Compiled: document_ingestion_mlflow_pipeline.yaml")