Files
kubeflow/vllm_tuning_pipeline.py
Billy D. bc4b230dd9 feat: add vLLM tuning pipeline + recompile voice pipelines with MLflow
New:
- vllm_tuning_pipeline.py: A/B benchmark different vLLM configs,
  logs latency/TPS/TTFT to MLflow (vllm-tuning experiment)
- vllm_tuning_pipeline.yaml: compiled KFP YAML

Updated:
- voice_pipeline.py: per-step NamedTuple outputs with latency tracking,
  new log_pipeline_metrics MLflow component
- voice_pipeline.yaml, tts_pipeline.yaml, rag_pipeline.yaml: recompiled
2026-02-13 08:24:11 -05:00

455 lines
15 KiB
Python

#!/usr/bin/env python3
"""
vLLM Tuning Evaluation Pipeline - Kubeflow Pipelines SDK
Runs inference benchmarks with different vLLM configurations and logs
results to MLflow so you can compare APC, chunked prefill, speculative
decoding, and GPU memory utilisation settings side-by-side.
Usage:
pip install kfp==2.12.1
python vllm_tuning_pipeline.py
# Upload vllm_tuning_pipeline.yaml to Kubeflow Pipelines UI
"""
from kfp import dsl
from kfp import compiler
from typing import NamedTuple
# Base image + package pins shared by the MLflow-facing components below.
# boto3/psycopg2 are pulled in for MLflow's S3 artifact + Postgres backends.
MLFLOW_IMAGE = "python:3.13-slim"
MLFLOW_PACKAGES = ["mlflow>=2.10.0", "boto3", "psycopg2-binary"]
# The benchmark component only needs an HTTP client.
BENCH_PACKAGES = ["httpx"]
# ---- MLflow components ----
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def create_tuning_run(
    experiment_name: str,
    run_name: str,
    tuning_params: dict,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> NamedTuple("RunInfo", [("run_id", str), ("experiment_id", str)]):
    """Create an MLflow run for a vLLM tuning experiment.

    Looks up (or creates) the experiment, starts a run tagged with the
    KFP run id, and logs every tuning knob as a param under a ``vllm.``
    prefix so runs can be compared side-by-side in the MLflow UI.

    Args:
        experiment_name: MLflow experiment to create the run under.
        run_name: Human-readable run name shown in the MLflow UI.
        tuning_params: Flat dict of vLLM settings; each key/value is
            logged as param ``vllm.<key>``.
        mlflow_tracking_uri: MLflow tracking server URL.

    Returns:
        RunInfo(run_id, experiment_id) — both strings, consumed by
        downstream logging components.
    """
    import os
    from collections import namedtuple

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    exp = client.get_experiment_by_name(experiment_name)
    experiment_id = (
        exp.experiment_id
        if exp
        else client.create_experiment(
            name=experiment_name,
            artifact_location=f"/mlflow/artifacts/{experiment_name}",
        )
    )

    tags = {
        "pipeline.type": "vllm-tuning",
        "kfp.run_id": os.environ.get("KFP_RUN_ID", "unknown"),
    }
    # Context manager guarantees the run is closed even if a log_param
    # call raises; the previous start_run/end_run pair left the run
    # stuck in RUNNING state on any logging error.
    with mlflow.start_run(
        experiment_id=experiment_id, run_name=run_name, tags=tags
    ) as run:
        for key, value in tuning_params.items():
            mlflow.log_param(f"vllm.{key}", value)
        run_id = run.info.run_id

    RunInfo = namedtuple("RunInfo", ["run_id", "experiment_id"])
    return RunInfo(run_id, experiment_id)
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def log_benchmark_results(
    run_id: str,
    metrics: dict,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """Log benchmark metrics to MLflow and close the run.

    Args:
        run_id: MLflow run id produced by create_tuning_run.
        metrics: Flat dict of metric name -> numeric value; also saved
            verbatim as a JSON artifact on the run.
        mlflow_tracking_uri: MLflow tracking server URL.

    Returns:
        The same run_id, for downstream chaining.

    Raises:
        Re-raises any logging failure after marking the run FAILED, so
        the run never lingers in RUNNING state.
    """
    import json
    import tempfile
    from pathlib import Path

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()
    try:
        for key, value in metrics.items():
            client.log_metric(run_id, key, float(value))
        # Save the full result dict as a JSON artifact for later inspection.
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "benchmark_results.json"
            path.write_text(json.dumps(metrics, indent=2))
            client.log_artifact(run_id, str(path))
    except Exception:
        # Previously a logging error left the run RUNNING forever;
        # terminate it as FAILED and let KFP see the step fail.
        client.set_terminated(run_id, status="FAILED")
        raise
    client.set_terminated(run_id, status="FINISHED")
    return run_id
# ---- Benchmark components ----
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=BENCH_PACKAGES,
)
def build_prompt_suite() -> list:
    """Return a list of test prompts spanning short, medium, and long inputs."""

    def make_prompt(pid, category, messages, max_tokens):
        # Shape one benchmark request the way run_benchmark consumes it.
        return {
            "id": pid,
            "category": category,
            "messages": messages,
            "max_tokens": max_tokens,
        }

    # Shared system message — reused verbatim so the prefix-cache test
    # prompt has an identical prefix to medium-1.
    homelab_system = {
        "role": "system",
        "content": "You are a helpful AI assistant running on a homelab.",
    }

    suite = []
    suite.append(
        make_prompt(
            "short-1",
            "short",
            [{"role": "user", "content": "What is the capital of France?"}],
            64,
        )
    )
    suite.append(
        make_prompt(
            "short-2",
            "short",
            [{"role": "user", "content": "Explain quantum computing in one sentence."}],
            64,
        )
    )
    suite.append(
        make_prompt(
            "medium-1",
            "medium",
            [
                dict(homelab_system),
                {
                    "role": "user",
                    "content": (
                        "Compare and contrast supervised and unsupervised "
                        "machine learning. Give examples of each and explain "
                        "when you would choose one over the other."
                    ),
                },
            ],
            512,
        )
    )
    suite.append(
        make_prompt(
            "medium-2",
            "medium",
            [
                {
                    "role": "user",
                    "content": (
                        "Write a Python function that implements a binary search "
                        "tree with insert, search, and delete operations. Include "
                        "docstrings and type hints."
                    ),
                },
            ],
            1024,
        )
    )
    suite.append(
        make_prompt(
            "long-1",
            "long",
            [
                {
                    "role": "system",
                    "content": "You are a technical writer for a Kubernetes homelab blog.",
                },
                {
                    "role": "user",
                    "content": (
                        "Write a detailed tutorial on setting up a multi-node "
                        "Kubernetes cluster with Talos Linux, covering: "
                        "1) Hardware requirements and network topology, "
                        "2) Talos machine config generation, "
                        "3) Control plane bootstrapping, "
                        "4) Worker node joining, "
                        "5) CNI setup with Cilium, "
                        "6) Storage with Rook-Ceph, "
                        "7) GitOps with Flux CD. "
                        "Include YAML examples for each step."
                    ),
                },
            ],
            2048,
        )
    )
    suite.append(
        make_prompt(
            "repeat-prefix-1",
            "prefix-cache-test",
            [
                dict(homelab_system),
                {
                    "role": "user",
                    "content": (
                        "Compare and contrast supervised and unsupervised "
                        "machine learning. Now focus specifically on "
                        "reinforcement learning and how it differs."
                    ),
                },
            ],
            512,
        )
    )
    return suite
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=BENCH_PACKAGES,
)
def run_benchmark(
    prompts: list,
    llm_endpoint: str,
    model_name: str,
    num_warmup: int = 2,
    num_iterations: int = 3,
) -> dict:
    """
    Run all prompts through the LLM endpoint and collect timing metrics.
    Returns aggregate metrics: p50/p95/mean latency, tokens/sec, TTFT.

    Args:
        prompts: Prompt suite (dicts with "messages", "max_tokens",
            and a "category" label) from build_prompt_suite.
        llm_endpoint: Base URL of an OpenAI-compatible chat endpoint.
        model_name: Model identifier sent with each request.
        num_warmup: Untimed requests sent first; failures are ignored.
        num_iterations: How many times the whole suite is repeated.

    Returns:
        Dict of aggregate metrics. Failed requests are counted in
        "failed_requests" but excluded from latency/TPS/TTFT stats.
    """
    import time
    import statistics
    import httpx

    all_latencies: list[float] = []
    all_tps: list[float] = []
    all_ttft: list[float] = []
    per_category: dict[str, list[float]] = {}

    with httpx.Client(timeout=300.0) as client:
        # Warmup: prime the server; results discarded, errors non-fatal.
        for _ in range(num_warmup):
            try:
                client.post(
                    f"{llm_endpoint}/v1/chat/completions",
                    json={
                        "model": model_name,
                        "messages": [{"role": "user", "content": "Hi"}],
                        "max_tokens": 8,
                        "temperature": 0,
                    },
                )
            except Exception:
                pass
        # Timed benchmark passes over the full suite.
        for iteration in range(num_iterations):
            for prompt in prompts:
                category = prompt.get("category", "unknown")
                payload = {
                    "model": model_name,
                    "messages": prompt["messages"],
                    "max_tokens": prompt.get("max_tokens", 256),
                    "temperature": 0,
                    "stream": True,
                }
                try:
                    t_start = time.perf_counter()
                    first_token_time = None
                    with client.stream(
                        "POST",
                        f"{llm_endpoint}/v1/chat/completions",
                        json=payload,
                    ) as resp:
                        resp.raise_for_status()
                        completion_tokens = 0
                        for line in resp.iter_lines():
                            if not line.startswith("data: "):
                                continue
                            chunk = line[6:]
                            if chunk == "[DONE]":
                                break
                            if first_token_time is None:
                                first_token_time = time.perf_counter()
                            # NOTE(review): counts SSE chunks, which
                            # approximates tokens (~1 token per chunk
                            # for typical streaming servers) — not an
                            # exact token count.
                            completion_tokens += 1
                    t_end = time.perf_counter()
                    latency = t_end - t_start
                    # BUGFIX: compare to None explicitly — a float
                    # timestamp must not be tested for truthiness.
                    ttft = (
                        (first_token_time - t_start)
                        if first_token_time is not None
                        else latency
                    )
                    tps = completion_tokens / latency if latency > 0 else 0
                    all_latencies.append(latency)
                    all_tps.append(tps)
                    all_ttft.append(ttft)
                    per_category.setdefault(category, []).append(latency)
                except Exception:
                    # Record the failure with sentinel values, keep going.
                    all_latencies.append(-1)
                    all_tps.append(0)
                    all_ttft.append(-1)

    # Aggregate over successful requests only (sentinels filtered out).
    valid_latencies = [l for l in all_latencies if l > 0]
    valid_tps = [t for t in all_tps if t > 0]
    valid_ttft = [t for t in all_ttft if t > 0]

    def safe_stat(values, func):
        # Guard against an all-failed run: return 0 instead of raising.
        return func(values) if values else 0

    def p95(values):
        # Crude p95: nearest-rank index into the sorted sample
        # (no interpolation; index is always < len for len >= 1).
        return sorted(values)[int(len(values) * 0.95)]

    metrics = {
        "total_requests": len(all_latencies),
        "successful_requests": len(valid_latencies),
        "failed_requests": len(all_latencies) - len(valid_latencies),
        # Latency
        "latency_mean_s": safe_stat(valid_latencies, statistics.mean),
        "latency_p50_s": safe_stat(valid_latencies, statistics.median),
        "latency_p95_s": safe_stat(valid_latencies, p95),
        # Throughput
        "tokens_per_second_mean": safe_stat(valid_tps, statistics.mean),
        "tokens_per_second_p50": safe_stat(valid_tps, statistics.median),
        # Time to first token
        "ttft_mean_s": safe_stat(valid_ttft, statistics.mean),
        "ttft_p50_s": safe_stat(valid_ttft, statistics.median),
        "ttft_p95_s": safe_stat(valid_ttft, p95),
    }
    # Per-category mean latency (successful requests only).
    for cat, lats in per_category.items():
        valid = [l for l in lats if l > 0]
        if valid:
            metrics[f"latency_mean_{cat}_s"] = statistics.mean(valid)
    return metrics
# ---- Pipeline ----
@dsl.pipeline(
    name="vllm-tuning-evaluation",
    description=(
        "Benchmark vLLM with different tuning configurations. "
        "Logs latency, TPS, and TTFT to MLflow for A/B comparison."
    ),
)
def vllm_tuning_pipeline(
    llm_endpoint: str = "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/llm",
    model_name: str = "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
    # Tuning knobs (match env vars in rayservice.yaml)
    enable_prefix_caching: str = "true",
    enable_chunked_prefill: str = "true",
    num_speculative_tokens: str = "3",
    ngram_prompt_lookup_max: str = "4",
    gpu_memory_utilization: str = "0.90",
    # Benchmark config
    num_warmup: int = 2,
    num_iterations: int = 3,
    run_label: str = "baseline",
):
    """
    vLLM Tuning Evaluation Pipeline

    Run this multiple times with different tuning params, then compare
    runs in the MLflow "vllm-tuning" experiment.

    Args:
        llm_endpoint: vLLM inference endpoint URL
        model_name: HF model identifier
        enable_prefix_caching: "true" or "false"
        enable_chunked_prefill: "true" or "false"
        num_speculative_tokens: number of speculative tokens (0 = off)
        ngram_prompt_lookup_max: ngram window for spec decode (0 = off)
        gpu_memory_utilization: 0.0 - 1.0
        num_warmup: warmup requests before timing
        num_iterations: how many times to repeat the prompt suite
        run_label: human-readable label (e.g. "apc-on-spec3")
    """
    # Every knob goes on the MLflow run so configurations are
    # directly comparable across runs of this pipeline.
    params_to_log = {
        "enable_prefix_caching": enable_prefix_caching,
        "enable_chunked_prefill": enable_chunked_prefill,
        "num_speculative_tokens": num_speculative_tokens,
        "ngram_prompt_lookup_max": ngram_prompt_lookup_max,
        "gpu_memory_utilization": gpu_memory_utilization,
        "model_name": model_name,
        "llm_endpoint": llm_endpoint,
        "num_warmup": str(num_warmup),
        "num_iterations": str(num_iterations),
    }

    # Step 1: open the MLflow run and record the tuning params.
    run_task = create_tuning_run(
        experiment_name="vllm-tuning",
        run_name=f"vllm-{run_label}",
        tuning_params=params_to_log,
    )

    # Step 2: build the prompt suite (static data — safe to cache).
    suite_task = build_prompt_suite()
    suite_task.set_caching_options(enable_caching=True)

    # Step 3: run the benchmark (never cached: timings must be fresh).
    benchmark_task = run_benchmark(
        prompts=suite_task.output,
        llm_endpoint=llm_endpoint,
        model_name=model_name,
        num_warmup=num_warmup,
        num_iterations=num_iterations,
    )
    benchmark_task.set_caching_options(enable_caching=False)

    # Step 4: push the measured metrics to MLflow and close the run.
    finalize_task = log_benchmark_results(
        run_id=run_task.outputs["run_id"],
        metrics=benchmark_task.output,
    )
if __name__ == "__main__":
    # Compile the pipeline definition to a KFP v2 YAML spec.
    compiler.Compiler().compile(
        vllm_tuning_pipeline,
        "vllm_tuning_pipeline.yaml",
    )
    # Print a usage crib sheet for comparing tuning configurations.
    usage = [
        "Compiled: vllm_tuning_pipeline.yaml",
        "",
        "Example runs to compare configurations:",
        " # Baseline (current config)",
        " kfp run submit vllm_tuning_pipeline.yaml --run-label=baseline",
        "",
        " # APC disabled",
        " kfp run submit vllm_tuning_pipeline.yaml \\",
        " --enable-prefix-caching=false --run-label=no-apc",
        "",
        " # No speculative decoding",
        " kfp run submit vllm_tuning_pipeline.yaml \\",
        " --num-speculative-tokens=0 --run-label=no-spec",
        "",
        " # Aggressive spec decode",
        " kfp run submit vllm_tuning_pipeline.yaml \\",
        " --num-speculative-tokens=5 --ngram-prompt-lookup-max=6 --run-label=spec5-ngram6",
    ]
    print("\n".join(usage))