258 lines
7.3 KiB
Python
258 lines
7.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Document Ingestion Pipeline - Kubeflow Pipelines SDK
|
|
|
|
Ingests documents into Milvus vector database with embeddings.
|
|
Can be triggered from Argo Workflows via the kfp-trigger template.
|
|
|
|
Usage:
|
|
pip install kfp==2.12.1
|
|
python document_ingestion_pipeline.py
|
|
# Upload document_ingestion.yaml to Kubeflow Pipelines UI
|
|
"""
|
|
|
|
from kfp import dsl
|
|
from kfp import compiler
|
|
from typing import List
|
|
|
|
|
|
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["httpx", "beautifulsoup4", "pypdf2", "docx2txt"]
)
def extract_text(
    source_url: str,
    source_type: str = "auto"
) -> str:
    """Download a document and extract its plain text.

    Args:
        source_url: HTTP(S) URL of the document to fetch.
        source_type: "pdf", "docx", "html", "text", or "auto" to detect
            the type from the URL's file extension.

    Returns:
        The extracted text (UTF-8 decoded, best effort for plain text).

    Raises:
        httpx.HTTPStatusError: If the download returns a 4xx/5xx status.
    """
    import os
    import tempfile

    import httpx

    with httpx.Client(timeout=120.0, follow_redirects=True) as client:
        response = client.get(source_url)
        # BUG FIX: previously a 404/500 body was silently parsed as a
        # document; fail fast on HTTP errors instead.
        response.raise_for_status()
        content = response.content

    if source_type == "auto":
        # BUG FIX: strip query string / fragment and lowercase before
        # sniffing the extension ("...X.PDF?token=..." now detects as pdf).
        path = source_url.split("?", 1)[0].split("#", 1)[0].lower()
        if path.endswith(".pdf"):
            source_type = "pdf"
        elif path.endswith(".docx"):
            source_type = "docx"
        elif path.endswith((".html", ".htm")):
            source_type = "html"
        else:
            source_type = "text"

    if source_type == "pdf":
        import io

        from PyPDF2 import PdfReader
        reader = PdfReader(io.BytesIO(content))
        # extract_text() returns None for image-only pages; substitute ""
        # so the join cannot raise TypeError.
        text = "\n\n".join((page.extract_text() or "") for page in reader.pages)
    elif source_type == "docx":
        import docx2txt
        # docx2txt only accepts a path, so spill the bytes to a temp file.
        with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as f:
            f.write(content)
        try:
            text = docx2txt.process(f.name)
        finally:
            # BUG FIX: delete=False leaked the temp file on every run.
            os.unlink(f.name)
    elif source_type == "html":
        from bs4 import BeautifulSoup
        soup = BeautifulSoup(content, "html.parser")
        text = soup.get_text(separator="\n")
    else:
        text = content.decode("utf-8", errors="ignore")

    return text
|
|
|
|
|
|
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["tiktoken"]
)
def chunk_text(
    text: str,
    chunk_size: int = 500,
    overlap: int = 50
) -> list:
    """Split text into overlapping, token-based chunks.

    Args:
        text: Raw document text to split.
        chunk_size: Maximum number of cl100k_base tokens per chunk.
        overlap: Tokens shared between consecutive chunks.

    Returns:
        A list of dicts with "text", "start_token", and "end_token" keys;
        empty list for empty input.
    """
    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(text)

    # BUG FIX: with overlap >= chunk_size the stride (chunk_size - overlap)
    # is <= 0 and the original loop never advanced — clamp to at least 1.
    step = max(1, chunk_size - overlap)

    chunks = []
    start = 0
    while start < len(tokens):
        end = min(start + chunk_size, len(tokens))
        # Renamed from "chunk_text" — the original shadowed the component's
        # own name with this local.
        piece = enc.decode(tokens[start:end])
        chunks.append({
            "text": piece,
            "start_token": start,
            "end_token": end
        })
        start += step

    return chunks
|
|
|
|
|
|
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["httpx"]
)
def generate_embeddings_batch(
    chunks: list,
    embeddings_url: str = "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings",
    batch_size: int = 32
) -> list:
    """Generate embeddings for all chunks via the inference service.

    Args:
        chunks: Chunk dicts (must each carry a "text" key).
        embeddings_url: OpenAI-compatible /embeddings endpoint; a base URL
            without the path is also accepted.
        batch_size: Number of texts per embedding request.

    Returns:
        Copies of the input chunks, each with an added "embedding" key, in
        the original order.

    Raises:
        httpx.HTTPStatusError: If the embedding service returns an error.
    """
    import httpx

    # BUG FIX: the old code always appended "/embeddings", but the default
    # URL already ends with that path, producing ".../embeddings/embeddings".
    # Only append the path when the caller passed a bare base URL.
    url = embeddings_url.rstrip("/")
    if not url.endswith("/embeddings"):
        url = f"{url}/embeddings"

    embedded_chunks = []
    texts = [c["text"] for c in chunks]

    with httpx.Client(timeout=300.0) as client:
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            response = client.post(
                url,
                json={"input": batch, "model": "bge-small-en-v1.5"}
            )
            # Fail fast with the HTTP status rather than a KeyError on
            # result["data"] below.
            response.raise_for_status()
            result = response.json()

            for j, embedding_data in enumerate(result["data"]):
                chunk = chunks[i + j].copy()
                chunk["embedding"] = embedding_data["embedding"]
                embedded_chunks.append(chunk)

    return embedded_chunks
|
|
|
|
|
|
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["pymilvus"]
)
def upsert_to_milvus(
    chunks: list,
    collection_name: str,
    source_name: str,
    milvus_host: str = "milvus.ai-ml.svc.cluster.local"
) -> int:
    """Insert embedded chunks into a Milvus collection, creating it if needed.

    Args:
        chunks: Chunk dicts, each with "text" and "embedding" keys.
        collection_name: Target Milvus collection.
        source_name: Source label stored alongside every chunk.
        milvus_host: Milvus service hostname (port 19530).

    Returns:
        The number of chunks inserted (0 for empty input).
    """
    from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility

    # Nothing to insert; also avoids reading chunks[0] below.
    if not chunks:
        return 0

    connections.connect(host=milvus_host, port=19530)
    try:
        # BUG FIX: the dimension was hard-coded to 1024, but the upstream
        # embedding model (bge-small-en-v1.5) emits 384-dim vectors, so
        # inserts into a freshly created collection would be rejected.
        # Derive the dimension from the actual data instead.
        dim = len(chunks[0]["embedding"])

        # Create the collection on first use.
        if not utility.has_collection(collection_name):
            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=8192),
                FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=512),
                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim)
            ]
            schema = CollectionSchema(fields, description="Document embeddings")
            Collection(name=collection_name, schema=schema)

        collection = Collection(collection_name)

        # Column-major data, in schema order (id is auto-generated).
        data = [
            [chunk["text"] for chunk in chunks],
            [source_name for _ in chunks],
            [chunk["embedding"] for chunk in chunks]
        ]

        collection.insert(data)
        collection.flush()

        # Build the vector index on first insert.
        if not collection.has_index():
            collection.create_index(
                field_name="embedding",
                index_params={
                    "metric_type": "COSINE",
                    "index_type": "HNSW",
                    "params": {"M": 16, "efConstruction": 256}
                }
            )

        return len(chunks)
    finally:
        # BUG FIX: release the connection instead of leaking it.
        connections.disconnect("default")
|
|
|
|
|
|
@dsl.pipeline(
    name="document-ingestion-pipeline",
    description="Ingest documents into Milvus: Extract -> Chunk -> Embed -> Store"
)
def document_ingestion_pipeline(
    source_url: str,
    collection_name: str = "knowledge_base",
    source_name: str = "",
    chunk_size: int = 500,
    chunk_overlap: int = 50
):
    """
    Document Ingestion Pipeline

    Args:
        source_url: URL to the document (PDF, DOCX, HTML, or plain text)
        collection_name: Milvus collection to store embeddings
        source_name: Human-readable name for the source
        chunk_size: Token count per chunk
        chunk_overlap: Overlap between chunks
    """

    # Step 1: Extract text from document (cached: same URL -> same text,
    # assuming the remote document is immutable).
    extract_task = extract_text(source_url=source_url)
    extract_task.set_caching_options(enable_caching=True)

    # Step 2: Chunk the text (deterministic, safe to cache)
    chunk_task = chunk_text(
        text=extract_task.output,
        chunk_size=chunk_size,
        overlap=chunk_overlap
    )
    chunk_task.set_caching_options(enable_caching=True)

    # Step 3: Generate embeddings (deterministic for a fixed model)
    embed_task = generate_embeddings_batch(chunks=chunk_task.output)
    embed_task.set_caching_options(enable_caching=True)

    # Step 4: Store in Milvus
    # NOTE(review): this conditional runs at pipeline *compile* time on a
    # PipelineChannel (which is always truthy), so the source_url fallback
    # for an empty source_name never takes effect at runtime; the fallback
    # belongs inside the component — confirm before relying on it.
    store_task = upsert_to_milvus(
        chunks=embed_task.output,
        collection_name=collection_name,
        source_name=source_name if source_name else source_url
    )
    # BUG FIX: never cache the side-effecting insert step — a cache hit
    # would silently skip writing to Milvus on re-runs.
    store_task.set_caching_options(enable_caching=False)
|
|
|
|
|
|
@dsl.pipeline(
    name="batch-document-ingestion",
    description="Ingest multiple documents in parallel"
)
def batch_document_ingestion_pipeline(
    source_urls: List[str],
    collection_name: str = "knowledge_base"
):
    """
    Batch Document Ingestion

    Fans out the single-document ingestion pipeline over a list of URLs.

    Args:
        source_urls: List of document URLs to ingest
        collection_name: Target Milvus collection
    """
    # One ingestion sub-pipeline per URL; iterations run concurrently.
    with dsl.ParallelFor(source_urls) as doc_url:
        document_ingestion_pipeline(
            source_url=doc_url,
            collection_name=collection_name,
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Primary pipeline — the compiled YAML must share the Python file's
    # base name so the two stay in sync.
    output_yaml = "document_ingestion_pipeline.yaml"
    compiler.Compiler().compile(document_ingestion_pipeline, output_yaml)
    print(f"Compiled: {output_yaml}")
|