Files
kubeflow/document_ingestion_mlflow_pipeline.py
Billy D. c26e4e5ef0 feat: Add Kubeflow Pipeline definitions
- voice_pipeline: STT → RAG → LLM → TTS
- document_ingestion_pipeline: Extract → Chunk → Embed → Milvus
- document_ingestion_mlflow_pipeline: With MLflow tracking
- evaluation_pipeline: Model benchmarking
- kfp-sync-job: K8s job to sync pipelines
2026-02-01 20:41:13 -05:00

467 lines
15 KiB
Python

#!/usr/bin/env python3
"""
MLflow-Integrated Document Ingestion Pipeline
Enhanced version of document_ingestion_pipeline.py with full MLflow
experiment tracking for metrics, parameters, and artifacts.
Usage:
pip install kfp==2.12.1 "mlflow>=2.10.0"
python document_ingestion_mlflow_pipeline.py
# then upload the generated document_ingestion_mlflow_pipeline.yaml via the Kubeflow Pipelines UI
"""
from kfp import dsl
from kfp import compiler
from typing import NamedTuple
# MLflow component configuration.
# Base container image shared by every lightweight component in this module.
MLFLOW_IMAGE = "python:3.13-slim"

# pip requirements for the extract / chunk / embed / upsert components.
# Built from small themed groups; the resulting list (and its order) is what
# gets passed to ``packages_to_install``.
MLFLOW_PACKAGES = [
    # MLflow tracking client and a PostgreSQL driver.
    "mlflow>=2.10.0",
    "psycopg2-binary>=2.9.0",
] + [
    # HTTP download and document parsing (HTML, PDF, DOCX).
    "httpx",
    "beautifulsoup4",
    "pypdf2",
    "docx2txt",
] + [
    # Tokenization and the Milvus client.
    "tiktoken",
    "pymilvus",
]
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=["mlflow>=2.10.0", "psycopg2-binary"],
)
def start_mlflow_run(
    experiment_name: str,
    run_name: str,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
    pipeline_params: dict = None,
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str)]):
    """Create the MLflow run for this pipeline execution and log its parameters.

    The run is created with ``MlflowClient.create_run`` and intentionally left
    in the RUNNING state: later components log metrics against the returned
    ``run_id`` and ``finalize_mlflow_run`` moves it to a terminal state.

    Args:
        experiment_name: MLflow experiment to use; created with artifact
            location ``/mlflow/artifacts/<experiment_name>`` if it is missing.
        run_name: Display name of the run. When empty, ``ingestion-<unix
            seconds>`` is generated here, at run time.
        mlflow_tracking_uri: MLflow tracking server URI.
        pipeline_params: Optional mapping logged as run parameters; each value
            is converted with ``str()`` and truncated to 500 characters.

    Returns:
        RunInfo(run_id, experiment_id) identifying the created run.
    """
    import os
    import time
    from collections import namedtuple

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Get or create the experiment.
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        experiment_id = client.create_experiment(
            name=experiment_name,
            artifact_location=f"/mlflow/artifacts/{experiment_name}",
        )
    else:
        experiment_id = experiment.experiment_id

    # Computed per execution so every run gets its own timestamped name.
    effective_run_name = run_name or f"ingestion-{int(time.time())}"

    tags = {
        "pipeline.type": "document-ingestion",
        "pipeline.framework": "kubeflow",
        "kfp.run_id": os.environ.get("KFP_RUN_ID", "unknown"),
    }
    # The fluent start_run()/end_run() pair marked the run FINISHED before any
    # downstream step had run; create_run leaves it RUNNING instead.
    run = client.create_run(
        experiment_id=experiment_id,
        tags=tags,
        run_name=effective_run_name,
    )
    run_id = run.info.run_id

    for key, value in (pipeline_params or {}).items():
        client.log_param(run_id, key, str(value)[:500])

    RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id'])
    return RunInfo(run_id, experiment_id)
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES,
)
def extract_text_with_tracking(
    source_url: str,
    run_id: str,
    source_type: str = "auto",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> NamedTuple('ExtractionResult', [('text', str), ('char_count', int), ('source_type', str)]):
    """Download a document, extract its plain text and log extraction stats.

    Args:
        source_url: HTTP(S) URL of the document. Redirects are followed; a
            non-success HTTP status raises instead of ingesting the error page.
        run_id: MLflow run that receives the params and metrics.
        source_type: ``"pdf"``, ``"docx"``, ``"html"``, ``"text"`` or
            ``"auto"``. ``"auto"`` checks the URL path's extension
            (case-insensitive, query string ignored), then the response
            Content-Type, and falls back to ``"text"``.
        mlflow_tracking_uri: MLflow tracking server URI.

    Returns:
        ExtractionResult(text, char_count, source_type), where ``source_type``
        is the resolved type actually used for extraction.

    Raises:
        httpx.HTTPStatusError: if the download returns a non-success status.
    """
    import io
    import os
    import tempfile
    import time
    from collections import namedtuple
    from pathlib import PurePosixPath
    from urllib.parse import urlparse

    import httpx
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()
    start_time = time.time()

    # httpx does not follow redirects unless asked to.
    with httpx.Client(timeout=120.0, follow_redirects=True) as http_client:
        response = http_client.get(source_url)
        response.raise_for_status()
        content = response.content
        content_type = response.headers.get("content-type", "")

    if source_type == "auto":
        # Use the URL *path* so "report.PDF?token=..." is still detected.
        suffix = PurePosixPath(urlparse(source_url).path).suffix.lower()
        mime = content_type.split(";")[0].strip().lower()
        by_suffix = {".pdf": "pdf", ".docx": "docx", ".html": "html", ".htm": "html"}
        by_mime = {
            "application/pdf": "pdf",
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx",
            "text/html": "html",
            "application/xhtml+xml": "html",
        }
        source_type = by_suffix.get(suffix) or by_mime.get(mime, "text")

    if source_type == "pdf":
        from PyPDF2 import PdfReader

        reader = PdfReader(io.BytesIO(content))
        # Guard against pages that yield no text (e.g. scanned images).
        text = "\n\n".join(page.extract_text() or "" for page in reader.pages)
    elif source_type == "docx":
        import docx2txt

        # docx2txt is given a file path: write the bytes, close the handle so
        # they are flushed to disk before reading, and always remove the file.
        fd, tmp_path = tempfile.mkstemp(suffix=".docx")
        try:
            with os.fdopen(fd, "wb") as tmp_file:
                tmp_file.write(content)
            text = docx2txt.process(tmp_path)
        finally:
            os.unlink(tmp_path)
    elif source_type == "html":
        from bs4 import BeautifulSoup

        soup = BeautifulSoup(content, "html.parser")
        text = soup.get_text(separator="\n")
    else:
        text = content.decode("utf-8", errors="ignore")

    extraction_time = time.time() - start_time

    # Log to MLflow.
    client.log_param(run_id, "source_url", source_url[:500])
    client.log_param(run_id, "source_type", source_type)
    client.log_metric(run_id, "extraction_time_seconds", extraction_time)
    client.log_metric(run_id, "source_size_bytes", len(content))
    client.log_metric(run_id, "extracted_char_count", len(text))

    ExtractionResult = namedtuple('ExtractionResult', ['text', 'char_count', 'source_type'])
    return ExtractionResult(text, len(text), source_type)
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES,
)
def chunk_text_with_tracking(
    text: str,
    run_id: str,
    chunk_size: int = 500,
    overlap: int = 50,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> NamedTuple('ChunkResult', [('chunks', list), ('chunk_count', int)]):
    """Split text into overlapping token windows and log chunking stats.

    Text is tokenized with tiktoken's ``cl100k_base`` encoding. Windows are
    ``chunk_size`` tokens long and consecutive windows share ``overlap``
    tokens; the final window may be shorter. Special-token strings such as
    ``<|endoftext|>`` appearing in the document are treated as ordinary text.

    Args:
        text: Text to split.
        run_id: MLflow run that receives the params and metrics.
        chunk_size: Tokens per window; must be positive.
        overlap: Tokens shared by consecutive windows; ``0 <= overlap < chunk_size``.
        mlflow_tracking_uri: MLflow tracking server URI.

    Returns:
        ChunkResult(chunks, chunk_count); each chunk is a dict with ``text``,
        ``start_token``, ``end_token`` (exclusive) and ``token_count``.

    Raises:
        ValueError: if ``chunk_size`` or ``overlap`` is out of range (the
            window would never advance, or tokens would be skipped).
    """
    import time
    from collections import namedtuple

    import mlflow
    import tiktoken
    from mlflow.tracking import MlflowClient

    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size ({chunk_size}), got {overlap}"
        )

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()
    start_time = time.time()

    enc = tiktoken.get_encoding("cl100k_base")
    # By default tiktoken raises on special-token text; documents may contain it.
    tokens = enc.encode(text, disallowed_special=())
    total_tokens = len(tokens)
    step = chunk_size - overlap

    chunks = []
    start = 0
    while start < total_tokens:
        end = min(start + chunk_size, total_tokens)
        window = tokens[start:end]
        chunks.append({
            "text": enc.decode(window),
            "start_token": start,
            "end_token": end,
            "token_count": len(window),
        })
        if end == total_tokens:
            # This window already reaches the end; stepping again would only
            # emit a suffix fully contained in it.
            break
        start += step

    chunking_time = time.time() - start_time

    # Log to MLflow.
    client.log_param(run_id, "chunk_size", chunk_size)
    client.log_param(run_id, "chunk_overlap", overlap)
    client.log_metric(run_id, "chunking_time_seconds", chunking_time)
    client.log_metric(run_id, "total_tokens", total_tokens)
    client.log_metric(run_id, "chunks_created", len(chunks))
    client.log_metric(run_id, "avg_tokens_per_chunk", total_tokens / len(chunks) if chunks else 0)

    ChunkResult = namedtuple('ChunkResult', ['chunks', 'chunk_count'])
    return ChunkResult(chunks, len(chunks))
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES,
)
def generate_embeddings_with_tracking(
    chunks: list,
    run_id: str,
    embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local",
    embeddings_model: str = "bge-small-en-v1.5",
    batch_size: int = 32,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> NamedTuple('EmbeddingResult', [('embedded_chunks', list), ('embedding_dim', int)]):
    """Embed each chunk's text via the embeddings service and log throughput.

    Texts are sent in batches to ``POST {embeddings_url}/embeddings`` with body
    ``{"input": [...], "model": embeddings_model}``. The response is read as
    ``{"data": [{"embedding": [...]}, ...]}``; when items carry an ``index``
    field (OpenAI-style, presumably), they are ordered by it before being
    paired with the batch.

    Args:
        chunks: Chunk dicts; each must have a ``text`` key.
        run_id: MLflow run that receives the params and metrics.
        embeddings_url: Base URL of the embeddings service.
        embeddings_model: Model name sent with every request.
        batch_size: Texts per request; must be positive.
        mlflow_tracking_uri: MLflow tracking server URI.

    Returns:
        EmbeddingResult(embedded_chunks, embedding_dim): copies of the input
        chunks with an added ``embedding`` key, and the vector length of the
        first embedding (0 when there were no chunks).

    Raises:
        ValueError: if ``batch_size`` is not positive, or a response does not
            contain exactly one embedding per input text.
        httpx.HTTPStatusError: if the service returns a non-success status.
    """
    import time
    from collections import namedtuple

    import httpx
    import mlflow
    from mlflow.tracking import MlflowClient

    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()
    start_time = time.time()

    embedded_chunks = []
    texts = [c["text"] for c in chunks]
    embedding_dim = 0

    with httpx.Client(timeout=300.0) as http_client:
        for offset in range(0, len(texts), batch_size):
            batch = texts[offset:offset + batch_size]
            response = http_client.post(
                f"{embeddings_url}/embeddings",
                json={"input": batch, "model": embeddings_model},
            )
            response.raise_for_status()
            items = response.json()["data"]
            if len(items) != len(batch):
                raise ValueError(
                    f"embeddings service returned {len(items)} vectors for {len(batch)} inputs"
                )
            # Stable sort: order is unchanged when no "index" field is present.
            items = sorted(items, key=lambda item: item.get("index", 0))

            for source_chunk, item in zip(chunks[offset:offset + batch_size], items):
                chunk = source_chunk.copy()
                chunk["embedding"] = item["embedding"]
                embedded_chunks.append(chunk)
                if embedding_dim == 0:
                    embedding_dim = len(item["embedding"])

    embedding_time = time.time() - start_time

    # Log to MLflow.
    client.log_param(run_id, "embeddings_model", embeddings_model)
    client.log_param(run_id, "embedding_batch_size", batch_size)
    client.log_metric(run_id, "embedding_time_seconds", embedding_time)
    client.log_metric(run_id, "embedding_dimension", embedding_dim)
    client.log_metric(run_id, "embeddings_per_second", len(chunks) / embedding_time if embedding_time > 0 else 0)

    EmbeddingResult = namedtuple('EmbeddingResult', ['embedded_chunks', 'embedding_dim'])
    return EmbeddingResult(embedded_chunks, embedding_dim)
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES,
)
def upsert_to_milvus_with_tracking(
    chunks: list,
    run_id: str,
    collection_name: str,
    source_name: str,
    milvus_host: str = "milvus.ai-ml.svc.cluster.local",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> NamedTuple('UpsertResult', [('inserted_count', int), ('collection_name', str)]):
    """Insert embedded chunks into a Milvus collection and log insert stats.

    If the collection does not exist it is created with fields ``id``
    (auto-generated INT64 primary key), ``text`` (VARCHAR 8192), ``source``
    (VARCHAR 512) and ``embedding`` (FLOAT_VECTOR sized from the first chunk).
    An HNSW/COSINE index on ``embedding`` is created if none exists yet.

    NOTE: despite the name this performs ``Collection.insert`` with auto IDs,
    so ingesting the same document twice stores duplicate rows.

    Args:
        chunks: Dicts with ``text`` and ``embedding`` keys. An empty list is a
            no-op (Milvus is not contacted) and reports 0 inserted rows.
        run_id: MLflow run that receives the params and metrics.
        collection_name: Target Milvus collection.
        source_name: Label stored in every row's ``source`` field; trimmed to
            at most 512 UTF-8 bytes so it fits the column.
        milvus_host: Milvus host; port 19530 is used.
        mlflow_tracking_uri: MLflow tracking server URI.

    Returns:
        UpsertResult(inserted_count, collection_name).
    """
    import time
    from collections import namedtuple

    import mlflow
    from mlflow.tracking import MlflowClient
    from pymilvus import (
        Collection,
        CollectionSchema,
        DataType,
        FieldSchema,
        connections,
        utility,
    )

    text_max_length = 8192
    source_max_length = 512

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()
    start_time = time.time()
    inserted_count = 0

    if chunks:
        # Trim on UTF-8 bytes (without splitting a character) so the value fits
        # the VARCHAR column whether the limit is counted in chars or bytes.
        source_value = (
            source_name.encode("utf-8")[:source_max_length].decode("utf-8", errors="ignore")
        )

        connections.connect(host=milvus_host, port=19530)
        try:
            # Create collection if needed.
            if not utility.has_collection(collection_name):
                fields = [
                    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                    FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=text_max_length),
                    FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=source_max_length),
                    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=len(chunks[0]["embedding"])),
                ]
                schema = CollectionSchema(fields, description="Document embeddings")
                Collection(name=collection_name, schema=schema)

            collection = Collection(collection_name)

            # Column-ordered data matching the non-auto fields above.
            data = [
                [chunk["text"] for chunk in chunks],
                [source_value] * len(chunks),
                [chunk["embedding"] for chunk in chunks],
            ]
            insert_result = collection.insert(data)
            collection.flush()

            # Create index if needed.
            if not collection.has_index():
                collection.create_index(
                    field_name="embedding",
                    index_params={
                        "metric_type": "COSINE",
                        "index_type": "HNSW",
                        "params": {"M": 16, "efConstruction": 256},
                    },
                )

            inserted_count = insert_result.insert_count
        finally:
            # Release the connection even when an insert/index call fails.
            connections.disconnect("default")

    upsert_time = time.time() - start_time

    # Log to MLflow.
    client.log_param(run_id, "milvus_collection", collection_name)
    client.log_param(run_id, "source_name", source_name[:500])
    client.log_metric(run_id, "upsert_time_seconds", upsert_time)
    client.log_metric(run_id, "documents_inserted", inserted_count)
    client.log_metric(run_id, "inserts_per_second", inserted_count / upsert_time if upsert_time > 0 else 0)

    UpsertResult = namedtuple('UpsertResult', ['inserted_count', 'collection_name'])
    return UpsertResult(inserted_count, collection_name)
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=["mlflow>=2.10.0", "psycopg2-binary"],
)
def finalize_mlflow_run(
    run_id: str,
    total_chunks: int,
    status: str = "FINISHED",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """Log summary metrics and move the MLflow run to a terminal state.

    Args:
        run_id: MLflow run to finalize.
        total_chunks: Logged as the ``final_chunk_count`` metric.
        status: Requested terminal status, case-insensitive: ``FINISHED``,
            ``FAILED`` or ``KILLED``; any other value is treated as
            ``FINISHED``. The raw value is stored in the ``pipeline.status`` tag.
        mlflow_tracking_uri: MLflow tracking server URI.

    Returns:
        The ``run_id`` that was finalized.
    """
    import mlflow
    from mlflow.entities import RunStatus
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # MlflowClient.set_terminated takes the status *name* (e.g. "FINISHED");
    # passing the integer RunStatus constant makes it raise.
    status_map = {
        "FINISHED": RunStatus.to_string(RunStatus.FINISHED),
        "FAILED": RunStatus.to_string(RunStatus.FAILED),
        "KILLED": RunStatus.to_string(RunStatus.KILLED),
    }
    run_status = status_map.get(status.upper(), status_map["FINISHED"])

    # Log summary, then close the run.
    client.set_tag(run_id, "pipeline.status", status)
    client.log_metric(run_id, "final_chunk_count", total_chunks)
    client.set_terminated(run_id, status=run_status)

    return run_id
@dsl.component(base_image=MLFLOW_IMAGE)
def resolve_ingestion_defaults(
    run_name: str,
    source_name: str,
    source_url: str,
) -> NamedTuple('IngestionDefaults', [('run_name', str), ('source_name', str)]):
    """Fill in empty optional pipeline parameters at run time.

    Returns:
        IngestionDefaults(run_name, source_name): ``run_name``, or
        ``ingestion-<unix seconds>`` when it is empty; ``source_name``, or
        ``source_url`` when it is empty.
    """
    import time
    from collections import namedtuple

    IngestionDefaults = namedtuple('IngestionDefaults', ['run_name', 'source_name'])
    return IngestionDefaults(
        run_name or f"ingestion-{int(time.time())}",
        source_name or source_url,
    )


@dsl.pipeline(
    name="mlflow-document-ingestion",
    description="Document ingestion with MLflow experiment tracking",
)
def document_ingestion_mlflow_pipeline(
    source_url: str,
    collection_name: str = "knowledge_base",
    source_name: str = "",
    chunk_size: int = 500,
    chunk_overlap: int = 50,
    experiment_name: str = "document-ingestion",
    run_name: str = "",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
):
    """
    Document Ingestion Pipeline with MLflow Tracking

    Ingests documents into Milvus while tracking all metrics
    and parameters in MLflow for experiment comparison.

    Steps: resolve defaults -> start MLflow run -> extract text -> chunk ->
    embed -> insert into Milvus -> finalize the MLflow run.

    Args:
        source_url: URL of the document to ingest.
        collection_name: Milvus collection receiving the chunks.
        source_name: Label stored with each chunk; defaults to source_url.
        chunk_size: Tokens per chunk.
        chunk_overlap: Tokens shared by consecutive chunks.
        experiment_name: MLflow experiment name.
        run_name: MLflow run name; defaults to "ingestion-<unix seconds>".
        mlflow_tracking_uri: MLflow tracking server URI.
    """
    # This function body runs once, at compile time, where parameters are KFP
    # placeholder objects (always truthy). Expressions like `run_name or ...`
    # therefore cannot apply defaults here; a component resolves them per run.
    defaults = resolve_ingestion_defaults(
        run_name=run_name,
        source_name=source_name,
        source_url=source_url,
    )
    # The generated name embeds a timestamp; a cached value would be stale.
    defaults.set_caching_options(enable_caching=False)

    # Start MLflow run.
    # NOTE(review): pipeline_params nests pipeline parameters inside a dict
    # argument; confirm the pinned kfp version resolves them at run time.
    mlflow_run = start_mlflow_run(
        experiment_name=experiment_name,
        run_name=defaults.outputs["run_name"],
        mlflow_tracking_uri=mlflow_tracking_uri,
        pipeline_params={
            "source_url": source_url,
            "collection_name": collection_name,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
        },
    )
    # A cache hit would return an earlier execution's run_id, so every later
    # step would log into (and re-terminate) that old run.
    mlflow_run.set_caching_options(enable_caching=False)
    run_id = mlflow_run.outputs["run_id"]

    # Extract text.
    extraction = extract_text_with_tracking(
        source_url=source_url,
        run_id=run_id,
        mlflow_tracking_uri=mlflow_tracking_uri,
    )
    extraction.set_caching_options(enable_caching=True)

    # Chunk text.
    chunking = chunk_text_with_tracking(
        text=extraction.outputs["text"],
        run_id=run_id,
        chunk_size=chunk_size,
        overlap=chunk_overlap,
        mlflow_tracking_uri=mlflow_tracking_uri,
    )
    chunking.set_caching_options(enable_caching=True)

    # Generate embeddings.
    embedding = generate_embeddings_with_tracking(
        chunks=chunking.outputs["chunks"],
        run_id=run_id,
        mlflow_tracking_uri=mlflow_tracking_uri,
    )
    embedding.set_caching_options(enable_caching=True)

    # Insert into Milvus: a side effect that must never be skipped by the cache.
    upsert = upsert_to_milvus_with_tracking(
        chunks=embedding.outputs["embedded_chunks"],
        run_id=run_id,
        collection_name=collection_name,
        source_name=defaults.outputs["source_name"],
        mlflow_tracking_uri=mlflow_tracking_uri,
    )
    upsert.set_caching_options(enable_caching=False)

    # Finalize run (also a side effect on the MLflow server).
    finalize = finalize_mlflow_run(
        run_id=run_id,
        total_chunks=upsert.outputs["inserted_count"],
        mlflow_tracking_uri=mlflow_tracking_uri,
    )
    finalize.set_caching_options(enable_caching=False)
if __name__ == "__main__":
    # Compile the pipeline to an IR YAML package for upload to Kubeflow.
    output_path = "document_ingestion_mlflow_pipeline.yaml"
    compiler.Compiler().compile(
        pipeline_func=document_ingestion_mlflow_pipeline,
        package_path=output_path,
    )
    print(f"Compiled: {output_path}")