Files
kubeflow/document_ingestion_pipeline.py
2026-02-02 07:12:05 -05:00

258 lines
7.3 KiB
Python

#!/usr/bin/env python3
"""
Document Ingestion Pipeline - Kubeflow Pipelines SDK
Ingests documents into Milvus vector database with embeddings.
Can be triggered from Argo Workflows via the kfp-trigger template.
Usage:
pip install kfp==2.12.1
python document_ingestion_pipeline.py
# Upload document_ingestion.yaml to Kubeflow Pipelines UI
"""
from kfp import dsl
from kfp import compiler
from typing import List
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["httpx", "beautifulsoup4", "pypdf2", "docx2txt"]
)
def extract_text(
    source_url: str,
    source_type: str = "auto"
) -> str:
    """Download a document over HTTP and extract its plain text.

    Args:
        source_url: URL of the document to fetch.
        source_type: One of "pdf", "docx", "html", "text", or "auto"
            (infer the format from the URL's file extension).

    Returns:
        The extracted text as a single string.

    Raises:
        httpx.HTTPStatusError: if the download returns a 4xx/5xx status.
    """
    import os
    import tempfile

    import httpx

    # follow_redirects: document hosts commonly redirect (e.g. to a CDN);
    # raise_for_status: fail fast instead of "extracting" an error page.
    with httpx.Client(timeout=120.0, follow_redirects=True) as client:
        response = client.get(source_url)
        response.raise_for_status()
        content = response.content

    # Infer the format from the URL extension when not given explicitly.
    if source_type == "auto":
        lowered = source_url.lower()
        if lowered.endswith(".pdf"):
            source_type = "pdf"
        elif lowered.endswith(".docx"):
            source_type = "docx"
        elif lowered.endswith((".html", ".htm")):
            source_type = "html"
        else:
            source_type = "text"

    if source_type == "pdf":
        import io

        from PyPDF2 import PdfReader
        reader = PdfReader(io.BytesIO(content))
        # extract_text() returns None for image-only pages; coalesce to "".
        text = "\n\n".join(page.extract_text() or "" for page in reader.pages)
    elif source_type == "docx":
        import docx2txt
        # delete=False so docx2txt can reopen the path on all platforms;
        # remove the temp file ourselves afterwards (the old code leaked it).
        with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as f:
            f.write(content)
        try:
            text = docx2txt.process(f.name)
        finally:
            os.unlink(f.name)
    elif source_type == "html":
        from bs4 import BeautifulSoup
        soup = BeautifulSoup(content, "html.parser")
        text = soup.get_text(separator="\n")
    else:
        text = content.decode("utf-8", errors="ignore")
    return text
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["tiktoken"]
)
def chunk_text(
    text: str,
    chunk_size: int = 500,
    overlap: int = 50
) -> list:
    """Split text into overlapping, token-based chunks.

    Args:
        text: The text to split.
        chunk_size: Maximum number of cl100k_base tokens per chunk.
        overlap: Number of tokens shared between consecutive chunks.

    Returns:
        A list of dicts with keys "text", "start_token", "end_token".

    Raises:
        ValueError: if chunk_size <= 0 or overlap >= chunk_size (the old
            code looped forever in that case because the window step was
            zero or negative).
    """
    import tiktoken

    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")

    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(text)
    step = chunk_size - overlap  # guaranteed > 0 by the checks above
    chunks = []
    for start in range(0, len(tokens), step):
        end = min(start + chunk_size, len(tokens))
        chunks.append({
            "text": enc.decode(tokens[start:end]),
            "start_token": start,
            "end_token": end
        })
    return chunks
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["httpx"]
)
def generate_embeddings_batch(
    chunks: list,
    embeddings_url: str = "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings",
    batch_size: int = 32
) -> list:
    """Generate embeddings for all chunks via an OpenAI-style embeddings API.

    Args:
        chunks: List of dicts each carrying a "text" key.
        embeddings_url: Full URL of the embeddings endpoint (including the
            /embeddings path).
        batch_size: Number of texts sent per request.

    Returns:
        Copies of the input chunks with an "embedding" key added, in order.

    Raises:
        httpx.HTTPStatusError: if the embeddings service returns an error.
    """
    import httpx

    embedded_chunks = []
    texts = [c["text"] for c in chunks]
    with httpx.Client(timeout=300.0) as client:
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            # BUG FIX: embeddings_url already contains the /embeddings path;
            # the old code appended "/embeddings" again, POSTing to
            # .../embeddings/embeddings.
            response = client.post(
                embeddings_url,
                json={"input": batch, "model": "bge-small-en-v1.5"}
            )
            # Surface service errors instead of KeyError-ing on "data".
            response.raise_for_status()
            result = response.json()
            for j, embedding_data in enumerate(result["data"]):
                chunk = chunks[i + j].copy()
                chunk["embedding"] = embedding_data["embedding"]
                embedded_chunks.append(chunk)
    return embedded_chunks
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["pymilvus"]
)
def upsert_to_milvus(
    chunks: list,
    collection_name: str,
    source_name: str,
    milvus_host: str = "milvus.ai-ml.svc.cluster.local"
) -> int:
    """Insert embedded chunks into a Milvus collection, creating it on demand.

    Args:
        chunks: Dicts with "text" and "embedding" keys (from the embed step).
        collection_name: Target Milvus collection.
        source_name: Source label stored with every chunk.
        milvus_host: Milvus service hostname (port 19530).

    Returns:
        The number of chunks inserted (0 for empty input).
    """
    from pymilvus import (
        Collection,
        CollectionSchema,
        DataType,
        FieldSchema,
        connections,
        utility,
    )

    if not chunks:
        # Nothing to store; also avoids creating a collection with a
        # dimension we cannot determine.
        return 0

    # Derive the vector dimension from the data instead of hard-coding it.
    # The old hard-coded dim=1024 did not match bge-small-en-v1.5, which
    # emits 384-dimensional vectors, so inserts into a fresh collection
    # would fail.
    dim = len(chunks[0]["embedding"])

    connections.connect(host=milvus_host, port=19530)
    try:
        # Create the collection on first use.
        if not utility.has_collection(collection_name):
            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=8192),
                FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=512),
                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim)
            ]
            schema = CollectionSchema(fields, description="Document embeddings")
            Collection(name=collection_name, schema=schema)

        collection = Collection(collection_name)

        # Column-ordered data matching the schema (id is auto-generated).
        data = [
            [chunk["text"] for chunk in chunks],
            [source_name for _ in chunks],
            [chunk["embedding"] for chunk in chunks]
        ]
        collection.insert(data)
        collection.flush()

        # Build the ANN index on first insert so the collection is searchable.
        if not collection.has_index():
            collection.create_index(
                field_name="embedding",
                index_params={
                    "metric_type": "COSINE",
                    "index_type": "HNSW",
                    "params": {"M": 16, "efConstruction": 256}
                }
            )
    finally:
        # connect() without an alias uses "default"; release it either way.
        connections.disconnect("default")
    return len(chunks)
@dsl.pipeline(
    name="document-ingestion-pipeline",
    description="Ingest documents into Milvus: Extract -> Chunk -> Embed -> Store"
)
def document_ingestion_pipeline(
    source_url: str,
    collection_name: str = "knowledge_base",
    source_name: str = "",
    chunk_size: int = 500,
    chunk_overlap: int = 50
):
    """Four-stage ingestion: extract text, chunk it, embed it, store it.

    Args:
        source_url: URL to the document (PDF, DOCX, HTML, or plain text)
        collection_name: Milvus collection to store embeddings
        source_name: Human-readable name for the source
        chunk_size: Token count per chunk
        chunk_overlap: Overlap between chunks
    """
    # Stage 1: pull down the document and extract raw text. Caching lets
    # re-runs of the same URL skip the download.
    extracted = extract_text(source_url=source_url)
    extracted.set_caching_options(enable_caching=True)

    # Stage 2: split the text into overlapping token windows.
    chunked = chunk_text(
        text=extracted.output,
        chunk_size=chunk_size,
        overlap=chunk_overlap
    )
    chunked.set_caching_options(enable_caching=True)

    # Stage 3: compute an embedding vector for every chunk.
    embedded = generate_embeddings_batch(chunks=chunked.output)
    embedded.set_caching_options(enable_caching=True)

    # Stage 4: persist vectors in Milvus; fall back to the URL when no
    # friendly source name was provided.
    upsert_to_milvus(
        chunks=embedded.output,
        collection_name=collection_name,
        source_name=source_name if source_name else source_url
    )
@dsl.pipeline(
    name="batch-document-ingestion",
    description="Ingest multiple documents in parallel"
)
def batch_document_ingestion_pipeline(
    source_urls: List[str],
    collection_name: str = "knowledge_base"
):
    """Fan the single-document pipeline out over a list of URLs.

    Args:
        source_urls: List of document URLs to ingest
        collection_name: Target Milvus collection
    """
    # Every URL gets its own independent, parallel run of the
    # single-document ingestion pipeline.
    with dsl.ParallelFor(source_urls) as doc_url:
        document_ingestion_pipeline(
            source_url=doc_url,
            collection_name=collection_name
        )
if __name__ == "__main__":
    # Primary pipeline - filename must match Python file for sync
    compiler.Compiler().compile(
        document_ingestion_pipeline,
        "document_ingestion_pipeline.yaml"
    )
    print("Compiled: document_ingestion_pipeline.yaml")

    # The batch pipeline was defined but never compiled; emit it too so it
    # can actually be uploaded. The primary output above is unchanged.
    compiler.Compiler().compile(
        batch_document_ingestion_pipeline,
        "batch_document_ingestion.yaml"
    )
    print("Compiled: batch_document_ingestion.yaml")