feat: Add Kubeflow Pipeline definitions
- voice_pipeline: STT → RAG → LLM → TTS - document_ingestion_pipeline: Extract → Chunk → Embed → Milvus - document_ingestion_mlflow_pipeline: With MLflow tracking - evaluation_pipeline: Model benchmarking - kfp-sync-job: K8s job to sync pipelines
This commit is contained in:
208
evaluation_pipeline.py
Normal file
208
evaluation_pipeline.py
Normal file
@@ -0,0 +1,208 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Model Evaluation Pipeline - Kubeflow Pipelines SDK
|
||||
|
||||
Evaluates fine-tuned models against benchmarks.
|
||||
Integrates with Argo Workflows for automated model deployment.
|
||||
|
||||
Usage:
|
||||
pip install kfp==2.12.1
|
||||
python evaluation_pipeline.py
|
||||
"""
|
||||
|
||||
from kfp import dsl
|
||||
from kfp import compiler
|
||||
from typing import Dict
|
||||
|
||||
|
||||
@dsl.component(
    base_image="python:3.13-slim"
)
def load_eval_dataset(
    dataset_name: str = "mmlu",
    subset: str = "test",
    limit: int = 100
) -> list:
    """Load evaluation dataset samples.

    Args:
        dataset_name: Benchmark dataset name. Currently unused — kept so the
            component interface is stable for when a real loader lands.
        subset: Dataset split to load. Currently unused (see above).
        limit: Maximum number of samples to return.

    Returns:
        List of sample dicts, each with "question", "choices" (4 options,
        A–D order), and "answer" (the correct letter).
    """
    # Placeholder data for now.
    # In production, this would load `dataset_name`/`subset` from
    # HuggingFace or S3.
    test_samples = [
        {
            "question": "What is the capital of France?",
            "choices": ["London", "Berlin", "Paris", "Madrid"],
            "answer": "C"
        },
        {
            "question": "Which planet is known as the Red Planet?",
            "choices": ["Venus", "Mars", "Jupiter", "Saturn"],
            "answer": "B"
        },
        {
            "question": "What is 2 + 2?",
            "choices": ["3", "4", "5", "6"],
            "answer": "B"
        }
    ]

    return test_samples[:limit]
|
||||
|
||||
|
||||
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["httpx"]
)
def run_inference(
    samples: list,
    model_endpoint: str,
    model_name: str = "default"
) -> list:
    """Run inference on evaluation samples.

    Sends each sample as a multiple-choice prompt to an OpenAI-compatible
    chat-completions endpoint and grades the first letter of the reply.

    Args:
        samples: Sample dicts with "question", "choices" (exactly 4
            options, A–D order), and "answer" (expected letter).
        model_endpoint: Base URL of the model inference service.
        model_name: Model identifier passed in the request body.

    Returns:
        List of per-sample dicts with "question", "expected", "predicted"
        ("X" when the model returned an empty reply), and "correct".

    Raises:
        httpx.HTTPStatusError: If the endpoint returns a non-2xx status.
    """
    import httpx

    results = []

    with httpx.Client(timeout=120.0) as client:
        for sample in samples:
            prompt = f"""Answer the following multiple choice question.

Question: {sample['question']}
Choices:
A) {sample['choices'][0]}
B) {sample['choices'][1]}
C) {sample['choices'][2]}
D) {sample['choices'][3]}

Answer with just the letter (A, B, C, or D):"""

            # temperature=0 for deterministic grading; max_tokens=10 since
            # only the leading letter of the reply is used.
            response = client.post(
                f"{model_endpoint}/v1/chat/completions",
                json={
                    "model": model_name,
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": 10,
                    "temperature": 0
                }
            )
            # Fail fast on HTTP errors instead of parsing an error body
            # as if it were a completion.
            response.raise_for_status()

            result = response.json()
            answer = result["choices"][0]["message"]["content"].strip().upper()

            results.append({
                "question": sample["question"],
                "expected": sample["answer"],
                "predicted": answer[0] if answer else "X",
                "correct": answer.startswith(sample["answer"])
            })

    return results
|
||||
|
||||
|
||||
@dsl.component(
    base_image="python:3.13-slim"
)
def calculate_metrics(
    results: list
) -> dict:
    """Aggregate per-sample results into summary evaluation metrics.

    Args:
        results: Per-sample dicts, each carrying a truthy/falsy "correct"
            flag (as produced by the inference step).

    Returns:
        Dict with "accuracy" (fraction correct, 0 when empty), "correct",
        "total", and a boolean "pass" flagging accuracy >= 0.7.
    """
    total = len(results)
    num_correct = sum(1 for entry in results if entry["correct"])

    accuracy = num_correct / total if total else 0

    return {
        "accuracy": accuracy,
        "correct": num_correct,
        "total": total,
        "pass": accuracy >= 0.7  # 70% threshold
    }
|
||||
|
||||
|
||||
@dsl.component(
    base_image="python:3.13-slim",
    # Declare the real dependency here instead of pip-installing it at
    # runtime via subprocess (the original listed unused "httpx" and
    # shelled out to install nats-py on every run).
    packages_to_install=["nats-py"]
)
def publish_results(
    metrics: dict,
    model_name: str,
    nats_url: str = "nats://nats.ai-ml.svc.cluster.local:4222"
) -> str:
    """Publish evaluation results to NATS.

    Emits one message on subject "ai.evaluation.results.<model_name>"
    containing the model name, metrics, and a passed/failed status derived
    from metrics["pass"].

    Args:
        metrics: Metrics dict from calculate_metrics; must contain "pass".
        model_name: Model identifier, used in the subject and payload.
        nats_url: NATS server URL (defaults to the in-cluster service).

    Returns:
        The literal string "published" on success.
    """
    import asyncio
    import json

    import nats

    async def publish():
        nc = await nats.connect(nats_url)
        await nc.publish(
            f"ai.evaluation.results.{model_name}",
            json.dumps({
                "model": model_name,
                "metrics": metrics,
                "status": "passed" if metrics["pass"] else "failed"
            }).encode()
        )
        await nc.close()

    asyncio.run(publish())
    return "published"
|
||||
|
||||
|
||||
@dsl.pipeline(
    name="model-evaluation-pipeline",
    description="Evaluate model performance on benchmarks"
)
def model_evaluation_pipeline(
    model_endpoint: str = "http://llm-draft.ai-ml.svc.cluster.local:8000",
    model_name: str = "default",
    dataset_name: str = "mmlu",
    sample_limit: int = 100
):
    """
    Model Evaluation Pipeline

    Load samples -> run inference -> score -> publish results to NATS.

    Args:
        model_endpoint: URL of the model inference endpoint
        model_name: Name of the model being evaluated
        dataset_name: Evaluation dataset to use
        sample_limit: Maximum samples to evaluate
    """

    # Load dataset — caching is safe: same dataset/limit yields same samples.
    load_task = load_eval_dataset(
        dataset_name=dataset_name,
        limit=sample_limit
    )
    load_task.set_caching_options(enable_caching=True)

    # Run inference — never cached, so a redeployed model under the same
    # endpoint/name is actually re-evaluated.
    inference_task = run_inference(
        samples=load_task.output,
        model_endpoint=model_endpoint,
        model_name=model_name
    )
    inference_task.set_caching_options(enable_caching=False)

    # Calculate metrics
    metrics_task = calculate_metrics(results=inference_task.output)

    # Publish results — the data dependency on metrics_task.output orders
    # this step last; no need to bind the task to an (unused) local.
    publish_results(
        metrics=metrics_task.output,
        model_name=model_name
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Compile the pipeline graph into a YAML package for upload to Kubeflow.
    package_path = "evaluation_pipeline.yaml"
    compiler.Compiler().compile(
        pipeline_func=model_evaluation_pipeline,
        package_path=package_path
    )
    print("Compiled: evaluation_pipeline.yaml")
|
||||
Reference in New Issue
Block a user