feat: Add Kubeflow Pipeline definitions
- voice_pipeline: STT → RAG → LLM → TTS
- document_ingestion_pipeline: Extract → Chunk → Embed → Milvus
- document_ingestion_mlflow_pipeline: With MLflow tracking
- evaluation_pipeline: Model benchmarking
- kfp-sync-job: K8s job to sync pipelines
This commit is contained in:
466
document_ingestion_mlflow_pipeline.py
Normal file
466
document_ingestion_mlflow_pipeline.py
Normal file
@@ -0,0 +1,466 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
MLflow-Integrated Document Ingestion Pipeline
|
||||
|
||||
Enhanced version of document_ingestion_pipeline.py with full MLflow
|
||||
experiment tracking for metrics, parameters, and artifacts.
|
||||
|
||||
Usage:
|
||||
pip install kfp==2.12.1 mlflow>=2.10.0
|
||||
python document_ingestion_mlflow_pipeline.py
|
||||
# Upload to Kubeflow Pipelines UI
|
||||
"""
|
||||
|
||||
from kfp import dsl
|
||||
from kfp import compiler
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
# MLflow component configuration
# Base container image shared by every component in this pipeline.
MLFLOW_IMAGE = "python:3.13-slim"
# Packages pip-installed into each component container at run time.
MLFLOW_PACKAGES = [
    "mlflow>=2.10.0",          # experiment-tracking client used by every step
    "psycopg2-binary>=2.9.0",  # PostgreSQL driver for the MLflow backend store
    "httpx",                   # HTTP client: document download + embeddings calls
    "beautifulsoup4",          # HTML text extraction
    "pypdf2",                  # PDF text extraction
    "docx2txt",                # DOCX text extraction
    "tiktoken",                # token-based chunking (cl100k_base encoding)
    "pymilvus",                # Milvus vector-store client
]
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=["mlflow>=2.10.0", "psycopg2-binary"]
)
def start_mlflow_run(
    experiment_name: str,
    run_name: str,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
    pipeline_params: dict = None,
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str)]):
    """Create an MLflow run and log initial pipeline parameters.

    Gets or creates the experiment, then creates the run through the
    low-level ``MlflowClient`` so it is left in RUNNING state for the
    downstream components that log against this run_id.

    BUG FIX: the previous implementation used ``mlflow.start_run()``
    followed by ``mlflow.end_run()``, which terminated the run as
    FINISHED before the pipeline even started; MLflow rejects
    ``log_param``/``log_metric`` calls against a terminated run, so every
    downstream tracking call would fail.  ``client.create_run`` leaves
    the run active (the pipeline's ``finalize_mlflow_run`` step
    terminates it).

    Args:
        experiment_name: MLflow experiment to create or reuse.
        run_name: Display name of the run.
        mlflow_tracking_uri: MLflow tracking server URL.
        pipeline_params: Optional dict of pipeline params to log
            (values truncated to MLflow's 500-char param limit).

    Returns:
        RunInfo(run_id, experiment_id).
    """
    import os
    from collections import namedtuple

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Get or create the experiment.
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        experiment_id = client.create_experiment(
            name=experiment_name,
            artifact_location=f"/mlflow/artifacts/{experiment_name}",
        )
    else:
        experiment_id = experiment.experiment_id

    # Tags identifying this run as a Kubeflow document-ingestion run.
    tags = {
        "pipeline.type": "document-ingestion",
        "pipeline.framework": "kubeflow",
        "kfp.run_id": os.environ.get("KFP_RUN_ID", "unknown"),
    }

    # create_run leaves the run RUNNING and touches no active-run global
    # state, so no end_run() is needed here.
    run = client.create_run(
        experiment_id=experiment_id,
        tags=tags,
        run_name=run_name,
    )
    run_id = run.info.run_id

    # Log pipeline parameters against the new run.
    if pipeline_params:
        for key, value in pipeline_params.items():
            client.log_param(run_id, key, str(value)[:500])

    RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id'])
    return RunInfo(run_id, experiment_id)
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def extract_text_with_tracking(
    source_url: str,
    run_id: str,
    source_type: str = "auto",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> NamedTuple('ExtractionResult', [('text', str), ('char_count', int), ('source_type', str)]):
    """Download a document, extract its text, and log metrics to MLflow.

    Args:
        source_url: URL of the document to ingest.
        run_id: MLflow run to log params/metrics against.
        source_type: "pdf" | "docx" | "html" | "text", or "auto" to
            detect from the URL extension.
        mlflow_tracking_uri: MLflow tracking server URL.

    Returns:
        ExtractionResult(text, char_count, source_type).

    Raises:
        httpx.HTTPStatusError: If the download returns an error status.
    """
    import io
    import os
    import tempfile
    import time
    from collections import namedtuple

    import httpx
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    start_time = time.time()

    # Download content.  raise_for_status() prevents silently ingesting
    # an HTML error page as document text (the original did not check).
    with httpx.Client(timeout=120.0, follow_redirects=True) as http_client:
        response = http_client.get(source_url)
        response.raise_for_status()
        content = response.content

    # Detect type from the URL extension when not given explicitly.
    if source_type == "auto":
        lowered = source_url.lower()
        if lowered.endswith(".pdf"):
            source_type = "pdf"
        elif lowered.endswith(".docx"):
            source_type = "docx"
        elif lowered.endswith((".html", ".htm")):
            source_type = "html"
        else:
            source_type = "text"

    # Extract text based on type.
    if source_type == "pdf":
        from PyPDF2 import PdfReader
        reader = PdfReader(io.BytesIO(content))
        # extract_text() may return None for image-only pages; coalesce
        # to "" so join() does not raise.
        text = "\n\n".join([page.extract_text() or "" for page in reader.pages])
    elif source_type == "docx":
        import docx2txt
        # delete=False is required so docx2txt can reopen the file by
        # name; clean it up explicitly afterwards (the original leaked it).
        with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as f:
            f.write(content)
        try:
            text = docx2txt.process(f.name)
        finally:
            os.unlink(f.name)
    elif source_type == "html":
        from bs4 import BeautifulSoup
        soup = BeautifulSoup(content, "html.parser")
        text = soup.get_text(separator="\n")
    else:
        text = content.decode("utf-8", errors="ignore")

    extraction_time = time.time() - start_time

    # Log extraction params/metrics to MLflow.
    client.log_param(run_id, "source_url", source_url[:500])
    client.log_param(run_id, "source_type", source_type)
    client.log_metric(run_id, "extraction_time_seconds", extraction_time)
    client.log_metric(run_id, "source_size_bytes", len(content))
    client.log_metric(run_id, "extracted_char_count", len(text))

    ExtractionResult = namedtuple('ExtractionResult', ['text', 'char_count', 'source_type'])
    return ExtractionResult(text, len(text), source_type)
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def chunk_text_with_tracking(
    text: str,
    run_id: str,
    chunk_size: int = 500,
    overlap: int = 50,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> NamedTuple('ChunkResult', [('chunks', list), ('chunk_count', int)]):
    """Split text into overlapping token windows with MLflow tracking.

    Args:
        text: Raw document text to split.
        run_id: MLflow run to log params/metrics against.
        chunk_size: Maximum tokens per chunk (cl100k_base encoding).
        overlap: Tokens shared between consecutive chunks.
        mlflow_tracking_uri: MLflow tracking server URL.

    Returns:
        ChunkResult(chunks, chunk_count); each chunk is a dict with
        "text", "start_token", "end_token" and "token_count" keys.

    Raises:
        ValueError: If chunk_size <= 0 or overlap not in [0, chunk_size).
    """
    import time
    from collections import namedtuple

    import mlflow
    import tiktoken
    from mlflow.tracking import MlflowClient

    # BUG FIX: the original advanced the window by chunk_size - overlap
    # with no validation — overlap >= chunk_size made the step <= 0 and
    # the while-loop spin forever.  Validate up front instead.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be > 0, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(f"overlap must be in [0, chunk_size), got {overlap}")

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    start_time = time.time()

    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(text)

    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(tokens), step):
        chunk_tokens = tokens[start:start + chunk_size]
        chunks.append({
            "text": enc.decode(chunk_tokens),
            "start_token": start,
            "end_token": start + len(chunk_tokens),
            "token_count": len(chunk_tokens),
        })

    chunking_time = time.time() - start_time

    # Log chunking params/metrics to MLflow.
    client.log_param(run_id, "chunk_size", chunk_size)
    client.log_param(run_id, "chunk_overlap", overlap)
    client.log_metric(run_id, "chunking_time_seconds", chunking_time)
    client.log_metric(run_id, "total_tokens", len(tokens))
    client.log_metric(run_id, "chunks_created", len(chunks))
    client.log_metric(run_id, "avg_tokens_per_chunk", len(tokens) / len(chunks) if chunks else 0)

    ChunkResult = namedtuple('ChunkResult', ['chunks', 'chunk_count'])
    return ChunkResult(chunks, len(chunks))
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def generate_embeddings_with_tracking(
    chunks: list,
    run_id: str,
    embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local",
    embeddings_model: str = "bge-small-en-v1.5",
    batch_size: int = 32,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> NamedTuple('EmbeddingResult', [('embedded_chunks', list), ('embedding_dim', int)]):
    """Call the embeddings service for each batch of chunks; log to MLflow.

    Args:
        chunks: Chunk dicts (each must have a "text" key).
        run_id: MLflow run to log params/metrics against.
        embeddings_url: Base URL of the embeddings service
            (OpenAI-style /embeddings endpoint).
        embeddings_model: Model name passed to the service.
        batch_size: Texts per request.
        mlflow_tracking_uri: MLflow tracking server URL.

    Returns:
        EmbeddingResult(embedded_chunks, embedding_dim) — the input
        chunks (copied) with an added "embedding" key, plus the vector
        dimension (0 when chunks is empty).

    Raises:
        httpx.HTTPStatusError: If the embeddings service returns an error.
    """
    import time
    from collections import namedtuple

    import httpx
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    start_time = time.time()

    embedded_chunks = []
    texts = [c["text"] for c in chunks]
    embedding_dim = 0

    with httpx.Client(timeout=300.0) as http_client:
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            response = http_client.post(
                f"{embeddings_url}/embeddings",
                json={"input": batch, "model": embeddings_model}
            )
            # BUG FIX: fail loudly on an HTTP error instead of crashing
            # later with a confusing KeyError on result["data"].
            response.raise_for_status()
            result = response.json()

            for j, embedding_data in enumerate(result["data"]):
                # Copy so the upstream chunk objects are not mutated.
                chunk = dict(chunks[i + j])
                chunk["embedding"] = embedding_data["embedding"]
                embedded_chunks.append(chunk)

                if embedding_dim == 0:
                    embedding_dim = len(embedding_data["embedding"])

    embedding_time = time.time() - start_time

    # Log embedding params/metrics to MLflow.
    client.log_param(run_id, "embeddings_model", embeddings_model)
    client.log_param(run_id, "embedding_batch_size", batch_size)
    client.log_metric(run_id, "embedding_time_seconds", embedding_time)
    client.log_metric(run_id, "embedding_dimension", embedding_dim)
    client.log_metric(run_id, "embeddings_per_second", len(chunks) / embedding_time if embedding_time > 0 else 0)

    EmbeddingResult = namedtuple('EmbeddingResult', ['embedded_chunks', 'embedding_dim'])
    return EmbeddingResult(embedded_chunks, embedding_dim)
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def upsert_to_milvus_with_tracking(
    chunks: list,
    run_id: str,
    collection_name: str,
    source_name: str,
    milvus_host: str = "milvus.ai-ml.svc.cluster.local",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> NamedTuple('UpsertResult', [('inserted_count', int), ('collection_name', str)]):
    """Insert embedded chunks into Milvus and log metrics to MLflow.

    Creates the collection (and an HNSW/COSINE index) if missing.

    Args:
        chunks: Chunk dicts carrying "text" and "embedding" keys.
        run_id: MLflow run to log params/metrics against.
        collection_name: Target Milvus collection.
        source_name: Source label stored alongside each chunk.
        milvus_host: Milvus service hostname (port 19530).
        mlflow_tracking_uri: MLflow tracking server URL.

    Returns:
        UpsertResult(inserted_count, collection_name).
    """
    import time
    from collections import namedtuple

    import mlflow
    from mlflow.tracking import MlflowClient
    from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    UpsertResult = namedtuple('UpsertResult', ['inserted_count', 'collection_name'])

    # BUG FIX: the original indexed chunks[0]["embedding"] unconditionally
    # and raised IndexError on an empty input; short-circuit instead.
    if not chunks:
        client.log_param(run_id, "milvus_collection", collection_name)
        client.log_metric(run_id, "documents_inserted", 0)
        return UpsertResult(0, collection_name)

    start_time = time.time()

    connections.connect(host=milvus_host, port=19530)
    # try/finally so the connection is released even when insert/index
    # creation fails (the original leaked it on error).
    try:
        # Create the collection on first use.
        if not utility.has_collection(collection_name):
            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=8192),
                FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=512),
                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=len(chunks[0]["embedding"]))
            ]
            schema = CollectionSchema(fields, description="Document embeddings")
            Collection(name=collection_name, schema=schema)

        collection = Collection(collection_name)

        # Column-ordered data matching the schema (id is auto-generated).
        data = [
            [chunk["text"] for chunk in chunks],
            [source_name for _ in chunks],
            [chunk["embedding"] for chunk in chunks]
        ]

        result = collection.insert(data)
        collection.flush()

        # Create the vector index on first use.
        if not collection.has_index():
            collection.create_index(
                field_name="embedding",
                index_params={
                    "metric_type": "COSINE",
                    "index_type": "HNSW",
                    "params": {"M": 16, "efConstruction": 256}
                }
            )
    finally:
        connections.disconnect("default")

    upsert_time = time.time() - start_time
    inserted_count = len(chunks)

    # Log upsert params/metrics to MLflow.
    client.log_param(run_id, "milvus_collection", collection_name)
    client.log_param(run_id, "source_name", source_name[:500])
    client.log_metric(run_id, "upsert_time_seconds", upsert_time)
    client.log_metric(run_id, "documents_inserted", inserted_count)
    client.log_metric(run_id, "inserts_per_second", inserted_count / upsert_time if upsert_time > 0 else 0)

    return UpsertResult(inserted_count, collection_name)
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=["mlflow>=2.10.0", "psycopg2-binary"]
)
def finalize_mlflow_run(
    run_id: str,
    total_chunks: int,
    status: str = "FINISHED",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """Finalize the MLflow run with summary metrics and terminate it.

    Args:
        run_id: MLflow run started by start_mlflow_run.
        total_chunks: Final document/chunk count to record.
        status: Terminal status ("FINISHED" or "FAILED"; anything else
            falls back to "FINISHED").
        mlflow_tracking_uri: MLflow tracking server URL.

    Returns:
        The run_id, for downstream chaining.
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # BUG FIX: MlflowClient.set_terminated expects the status as a
    # *string* (e.g. "FINISHED"); the original passed a RunStatus enum
    # value (an int), which the client rejects.
    final_status = status.upper()
    if final_status not in ("FINISHED", "FAILED"):
        final_status = "FINISHED"

    # Log run summary.
    client.set_tag(run_id, "pipeline.status", status)
    client.log_metric(run_id, "final_chunk_count", total_chunks)

    client.set_terminated(run_id, status=final_status)

    return run_id
|
||||
|
||||
|
||||
@dsl.pipeline(
    name="mlflow-document-ingestion",
    description="Document ingestion with MLflow experiment tracking"
)
def document_ingestion_mlflow_pipeline(
    source_url: str,
    collection_name: str = "knowledge_base",
    source_name: str = "",
    chunk_size: int = 500,
    chunk_overlap: int = 50,
    experiment_name: str = "document-ingestion",
    run_name: str = "",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
):
    """
    Document Ingestion Pipeline with MLflow Tracking

    Ingests documents into Milvus while tracking all metrics
    and parameters in MLflow for experiment comparison.

    Step order: start MLflow run -> extract -> chunk -> embed -> upsert
    -> finalize.  Every step receives the run_id emitted by
    start_mlflow_run and logs its own params/metrics against that run;
    the data dependencies between outputs also define the execution
    order, so no explicit .after() calls are needed.
    """
    import time

    # Generate run name if not provided.
    # NOTE(review): this function body executes at pipeline *compile*
    # time, where `run_name` is a KFP parameter channel (an object,
    # always truthy) rather than the runtime string — so this fallback
    # likely never fires and an empty run_name reaches the component.
    # Confirm, and consider moving the default into start_mlflow_run.
    effective_run_name = run_name or f"ingestion-{int(time.time())}"

    # Start MLflow run; its run_id output is threaded through every step.
    mlflow_run = start_mlflow_run(
        experiment_name=experiment_name,
        run_name=effective_run_name,
        mlflow_tracking_uri=mlflow_tracking_uri,
        pipeline_params={
            "source_url": source_url,
            "collection_name": collection_name,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
        },
    )

    # Extract text from the source document.
    extraction = extract_text_with_tracking(
        source_url=source_url,
        run_id=mlflow_run.outputs["run_id"],
        mlflow_tracking_uri=mlflow_tracking_uri,
    )
    # Caching lets re-runs with identical inputs skip the download.
    extraction.set_caching_options(enable_caching=True)

    # Chunk text into overlapping token windows.
    chunking = chunk_text_with_tracking(
        text=extraction.outputs["text"],
        run_id=mlflow_run.outputs["run_id"],
        chunk_size=chunk_size,
        overlap=chunk_overlap,
        mlflow_tracking_uri=mlflow_tracking_uri,
    )
    chunking.set_caching_options(enable_caching=True)

    # Generate embeddings for each chunk.
    embedding = generate_embeddings_with_tracking(
        chunks=chunking.outputs["chunks"],
        run_id=mlflow_run.outputs["run_id"],
        mlflow_tracking_uri=mlflow_tracking_uri,
    )
    embedding.set_caching_options(enable_caching=True)

    # Upsert embedded chunks to Milvus (never cached: it has side effects).
    # NOTE(review): like the run_name fallback above, `source_name or
    # source_url` is evaluated at compile time on parameter channels —
    # verify the intended default actually takes effect at runtime.
    upsert = upsert_to_milvus_with_tracking(
        chunks=embedding.outputs["embedded_chunks"],
        run_id=mlflow_run.outputs["run_id"],
        collection_name=collection_name,
        source_name=source_name or source_url,
        mlflow_tracking_uri=mlflow_tracking_uri,
    )

    # Finalize (terminate) the MLflow run with summary metrics.
    finalize_mlflow_run(
        run_id=mlflow_run.outputs["run_id"],
        total_chunks=upsert.outputs["inserted_count"],
        mlflow_tracking_uri=mlflow_tracking_uri,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Compile the pipeline definition into a YAML spec that can be
    # uploaded through the Kubeflow Pipelines UI.
    compiler.Compiler().compile(
        pipeline_func=document_ingestion_mlflow_pipeline,
        package_path="document_ingestion_mlflow_pipeline.yaml",
    )
    print("Compiled: document_ingestion_mlflow_pipeline.yaml")
|
||||
Reference in New Issue
Block a user