feat: add vLLM tuning pipeline + recompile voice pipelines with MLflow

New:
- vllm_tuning_pipeline.py: A/B benchmark different vLLM configs, logs latency/TPS/TTFT to MLflow (vllm-tuning experiment)
- vllm_tuning_pipeline.yaml: compiled KFP YAML

Updated:
- voice_pipeline.py: per-step NamedTuple outputs with latency tracking, new log_pipeline_metrics MLflow component
- voice_pipeline.yaml, tts_pipeline.yaml, rag_pipeline.yaml: recompiled
vllm_tuning_pipeline.py (new file, 454 lines)
@@ -0,0 +1,454 @@
#!/usr/bin/env python3
"""
vLLM Tuning Evaluation Pipeline - Kubeflow Pipelines SDK

Runs inference benchmarks with different vLLM configurations and logs
results to MLflow so you can compare APC, chunked prefill, speculative
decoding, and GPU memory utilisation settings side-by-side.

Usage:
    pip install kfp==2.12.1
    python vllm_tuning_pipeline.py
    # Upload vllm_tuning_pipeline.yaml to Kubeflow Pipelines UI
"""

from kfp import dsl
from kfp import compiler
from typing import NamedTuple


MLFLOW_IMAGE = "python:3.13-slim"
MLFLOW_PACKAGES = ["mlflow>=2.10.0", "boto3", "psycopg2-binary"]
BENCH_PACKAGES = ["httpx"]


# ---- MLflow components ----


@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def create_tuning_run(
    experiment_name: str,
    run_name: str,
    tuning_params: dict,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> NamedTuple("RunInfo", [("run_id", str), ("experiment_id", str)]):
    """Create an MLflow run for a vLLM tuning experiment."""
    import os
    import mlflow
    from mlflow.tracking import MlflowClient
    from collections import namedtuple

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Reuse the experiment if it already exists, otherwise create it
    exp = client.get_experiment_by_name(experiment_name)
    experiment_id = (
        exp.experiment_id
        if exp
        else client.create_experiment(
            name=experiment_name,
            artifact_location=f"/mlflow/artifacts/{experiment_name}",
        )
    )

    tags = {
        "pipeline.type": "vllm-tuning",
        "kfp.run_id": os.environ.get("KFP_RUN_ID", "unknown"),
    }

    run = mlflow.start_run(
        experiment_id=experiment_id, run_name=run_name, tags=tags
    )
    # Log every tuning param
    for key, value in tuning_params.items():
        mlflow.log_param(f"vllm.{key}", value)
    run_id = run.info.run_id
    # Close the fluent run context; downstream steps log metrics to this
    # run_id via MlflowClient from their own containers
    mlflow.end_run()

    RunInfo = namedtuple("RunInfo", ["run_id", "experiment_id"])
    return RunInfo(run_id, experiment_id)


@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def log_benchmark_results(
    run_id: str,
    metrics: dict,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """Log benchmark metrics to MLflow and close the run."""
    import json
    import tempfile
    import mlflow
    from mlflow.tracking import MlflowClient
    from pathlib import Path

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    for key, value in metrics.items():
        client.log_metric(run_id, key, float(value))

    # Save full results as artifact
    with tempfile.TemporaryDirectory() as tmpdir:
        path = Path(tmpdir) / "benchmark_results.json"
        path.write_text(json.dumps(metrics, indent=2))
        client.log_artifact(run_id, str(path))

    client.set_terminated(run_id, status="FINISHED")
    return run_id
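

# To spot-check a finished run from a notebook or shell, the same client API
# can read the metrics back; a minimal sketch (the run_id would come from the
# MLflow UI or from the pipeline output, not from this file):
#
#   from mlflow.tracking import MlflowClient
#   c = MlflowClient(tracking_uri="http://mlflow.mlflow.svc.cluster.local:80")
#   run = c.get_run("<run_id>")
#   print(run.data.metrics["tokens_per_second_mean"])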


# ---- Benchmark components ----


@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=BENCH_PACKAGES,
)
def build_prompt_suite() -> list:
    """Return a list of test prompts spanning short, medium, and long inputs."""
    return [
        {
            "id": "short-1",
            "category": "short",
            "messages": [
                {"role": "user", "content": "What is the capital of France?"}
            ],
            "max_tokens": 64,
        },
        {
            "id": "short-2",
            "category": "short",
            "messages": [
                {"role": "user", "content": "Explain quantum computing in one sentence."}
            ],
            "max_tokens": 64,
        },
        {
            "id": "medium-1",
            "category": "medium",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful AI assistant running on a homelab.",
                },
                {
                    "role": "user",
                    "content": (
                        "Compare and contrast supervised and unsupervised "
                        "machine learning. Give examples of each and explain "
                        "when you would choose one over the other."
                    ),
                },
            ],
            "max_tokens": 512,
        },
        {
            "id": "medium-2",
            "category": "medium",
            "messages": [
                {
                    "role": "user",
                    "content": (
                        "Write a Python function that implements a binary search "
                        "tree with insert, search, and delete operations. Include "
                        "docstrings and type hints."
                    ),
                },
            ],
            "max_tokens": 1024,
        },
        {
            "id": "long-1",
            "category": "long",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a technical writer for a Kubernetes homelab blog.",
                },
                {
                    "role": "user",
                    "content": (
                        "Write a detailed tutorial on setting up a multi-node "
                        "Kubernetes cluster with Talos Linux, covering: "
                        "1) Hardware requirements and network topology, "
                        "2) Talos machine config generation, "
                        "3) Control plane bootstrapping, "
                        "4) Worker node joining, "
                        "5) CNI setup with Cilium, "
                        "6) Storage with Rook-Ceph, "
                        "7) GitOps with Flux CD. "
                        "Include YAML examples for each step."
                    ),
                },
            ],
            "max_tokens": 2048,
        },
        {
            # Shares its system prompt and opening user text with medium-1 so
            # repeated runs exercise automatic prefix caching (APC)
            "id": "repeat-prefix-1",
            "category": "prefix-cache-test",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful AI assistant running on a homelab.",
                },
                {
                    "role": "user",
                    "content": (
                        "Compare and contrast supervised and unsupervised "
                        "machine learning. Now focus specifically on "
                        "reinforcement learning and how it differs."
                    ),
                },
            ],
            "max_tokens": 512,
        },
    ]


@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=BENCH_PACKAGES,
)
def run_benchmark(
    prompts: list,
    llm_endpoint: str,
    model_name: str,
    num_warmup: int = 2,
    num_iterations: int = 3,
) -> dict:
    """
    Run all prompts through the LLM endpoint and collect timing metrics.

    Returns aggregate metrics: p50/p95/mean latency, tokens/sec, TTFT.
    """
    import time
    import statistics
    import httpx

    all_latencies: list[float] = []
    all_tps: list[float] = []
    all_ttft: list[float] = []
    per_category: dict[str, list[float]] = {}

    with httpx.Client(timeout=300.0) as client:
        # Warmup requests: fired and discarded so server-side caches settle
        for _ in range(num_warmup):
            try:
                client.post(
                    f"{llm_endpoint}/v1/chat/completions",
                    json={
                        "model": model_name,
                        "messages": [{"role": "user", "content": "Hi"}],
                        "max_tokens": 8,
                        "temperature": 0,
                    },
                )
            except Exception:
                pass

        # Benchmark
        for iteration in range(num_iterations):
            for prompt in prompts:
                category = prompt.get("category", "unknown")
                payload = {
                    "model": model_name,
                    "messages": prompt["messages"],
                    "max_tokens": prompt.get("max_tokens", 256),
                    "temperature": 0,
                    "stream": True,
                }

                try:
                    t_start = time.perf_counter()
                    first_token_time = None

                    with client.stream(
                        "POST",
                        f"{llm_endpoint}/v1/chat/completions",
                        json=payload,
                    ) as resp:
                        resp.raise_for_status()
                        # Each SSE "data:" chunk is counted as one completion
                        # token, a close approximation for streamed output
                        completion_tokens = 0
                        for line in resp.iter_lines():
                            if not line.startswith("data: "):
                                continue
                            chunk = line[6:]
                            if chunk == "[DONE]":
                                break
                            if first_token_time is None:
                                first_token_time = time.perf_counter()
                            completion_tokens += 1

                    t_end = time.perf_counter()
                    latency = t_end - t_start
                    ttft = (
                        (first_token_time - t_start)
                        if first_token_time
                        else latency
                    )
                    tps = (
                        completion_tokens / latency if latency > 0 else 0
                    )

                    all_latencies.append(latency)
                    all_tps.append(tps)
                    all_ttft.append(ttft)
                    per_category.setdefault(category, []).append(latency)

                except Exception:
                    # Record failure but keep going
                    all_latencies.append(-1)
                    all_tps.append(0)
                    all_ttft.append(-1)

    # Compute aggregates over successful requests only
    valid_latencies = [l for l in all_latencies if l > 0]
    valid_tps = [t for t in all_tps if t > 0]
    valid_ttft = [t for t in all_ttft if t > 0]

    def safe_stat(values, func):
        return func(values) if values else 0

    metrics = {
        "total_requests": len(all_latencies),
        "successful_requests": len(valid_latencies),
        "failed_requests": len(all_latencies) - len(valid_latencies),
        # Latency
        "latency_mean_s": safe_stat(valid_latencies, statistics.mean),
        "latency_p50_s": safe_stat(
            valid_latencies,
            lambda v: statistics.median(v),
        ),
        # Nearest-rank p95: index into the sorted sample
        "latency_p95_s": safe_stat(
            valid_latencies,
            lambda v: sorted(v)[int(len(v) * 0.95)] if v else 0,
        ),
        # Throughput
        "tokens_per_second_mean": safe_stat(valid_tps, statistics.mean),
        "tokens_per_second_p50": safe_stat(
            valid_tps, lambda v: statistics.median(v)
        ),
        # Time to first token
        "ttft_mean_s": safe_stat(valid_ttft, statistics.mean),
        "ttft_p50_s": safe_stat(valid_ttft, lambda v: statistics.median(v)),
        "ttft_p95_s": safe_stat(
            valid_ttft,
            lambda v: sorted(v)[int(len(v) * 0.95)] if v else 0,
        ),
    }

    # Per-category latency
    for cat, lats in per_category.items():
        valid = [l for l in lats if l > 0]
        if valid:
            metrics[f"latency_mean_{cat}_s"] = statistics.mean(valid)

    return metrics
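

# The p95 figures above use a simple nearest-rank pick from the sorted sample.
# For reference only (not what the component does), an interpolated percentile
# of any latency sample can be taken from the standard library instead:
#
#   import statistics
#   p95 = statistics.quantiles(latencies, n=20)[-1]   # last cut point = 95th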


# ---- Pipeline ----


@dsl.pipeline(
    name="vllm-tuning-evaluation",
    description=(
        "Benchmark vLLM with different tuning configurations. "
        "Logs latency, TPS, and TTFT to MLflow for A/B comparison."
    ),
)
def vllm_tuning_pipeline(
    llm_endpoint: str = "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/llm",
    model_name: str = "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
    # Tuning knobs (match env vars in rayservice.yaml)
    enable_prefix_caching: str = "true",
    enable_chunked_prefill: str = "true",
    num_speculative_tokens: str = "3",
    ngram_prompt_lookup_max: str = "4",
    gpu_memory_utilization: str = "0.90",
    # Benchmark config
    num_warmup: int = 2,
    num_iterations: int = 3,
    run_label: str = "baseline",
):
    """
    vLLM Tuning Evaluation Pipeline

    Run this multiple times with different tuning params, then compare
    runs in the MLflow "vllm-tuning" experiment.

    Args:
        llm_endpoint: vLLM inference endpoint URL
        model_name: HF model identifier
        enable_prefix_caching: "true" or "false"
        enable_chunked_prefill: "true" or "false"
        num_speculative_tokens: number of speculative tokens (0 = off)
        ngram_prompt_lookup_max: ngram window for spec decode (0 = off)
        gpu_memory_utilization: 0.0 - 1.0
        num_warmup: warmup requests before timing
        num_iterations: how many times to repeat the prompt suite
        run_label: human-readable label (e.g. "apc-on-spec3")
    """

    tuning_params = {
        "enable_prefix_caching": enable_prefix_caching,
        "enable_chunked_prefill": enable_chunked_prefill,
        "num_speculative_tokens": num_speculative_tokens,
        "ngram_prompt_lookup_max": ngram_prompt_lookup_max,
        "gpu_memory_utilization": gpu_memory_utilization,
        "model_name": model_name,
        "llm_endpoint": llm_endpoint,
        "num_warmup": str(num_warmup),
        "num_iterations": str(num_iterations),
    }

    # 1. Create MLflow run
    mlflow_run = create_tuning_run(
        experiment_name="vllm-tuning",
        run_name=f"vllm-{run_label}",
        tuning_params=tuning_params,
    )

    # 2. Build prompt suite (static output, so KFP caching is safe)
    prompts_task = build_prompt_suite()
    prompts_task.set_caching_options(enable_caching=True)

    # 3. Run benchmark (never cached: timings must be re-measured every run)
    bench_task = run_benchmark(
        prompts=prompts_task.output,
        llm_endpoint=llm_endpoint,
        model_name=model_name,
        num_warmup=num_warmup,
        num_iterations=num_iterations,
    )
    bench_task.set_caching_options(enable_caching=False)

    # 4. Log results to MLflow
    log_task = log_benchmark_results(
        run_id=mlflow_run.outputs["run_id"],
        metrics=bench_task.output,
    )
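

# Once a few runs with different knobs have finished, they can also be compared
# outside the MLflow UI; a minimal sketch (assumes the same tracking URI the
# components use, and mlflow installed wherever this is executed):
#
#   import mlflow
#   mlflow.set_tracking_uri("http://mlflow.mlflow.svc.cluster.local:80")
#   runs = mlflow.search_runs(
#       experiment_names=["vllm-tuning"],
#       order_by=["metrics.tokens_per_second_mean DESC"],
#   )
#   print(runs[["tags.mlflow.runName", "metrics.tokens_per_second_mean",
#               "metrics.ttft_p50_s", "metrics.latency_p95_s"]])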


if __name__ == "__main__":
    compiler.Compiler().compile(
        vllm_tuning_pipeline,
        "vllm_tuning_pipeline.yaml",
    )
    print("Compiled: vllm_tuning_pipeline.yaml")
    print()
    print("Example runs to compare configurations:")
    print("  # Baseline (current config)")
    print("  kfp run submit vllm_tuning_pipeline.yaml --run-label=baseline")
    print()
    print("  # APC disabled")
    print("  kfp run submit vllm_tuning_pipeline.yaml \\")
    print("    --enable-prefix-caching=false --run-label=no-apc")
    print()
    print("  # No speculative decoding")
    print("  kfp run submit vllm_tuning_pipeline.yaml \\")
    print("    --num-speculative-tokens=0 --run-label=no-spec")
    print()
    print("  # Aggressive spec decode")
    print("  kfp run submit vllm_tuning_pipeline.yaml \\")
    print("    --num-speculative-tokens=5 --ngram-prompt-lookup-max=6 --run-label=spec5-ngram6")
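
# The same submission can also be done programmatically instead of via the CLI
# or UI; a minimal sketch using the KFP SDK client (the host below is a
# placeholder for your own KFP API endpoint):
#
#   from kfp import Client
#   client = Client(host="http://<kfp-api-host>")
#   client.create_run_from_pipeline_package(
#       "vllm_tuning_pipeline.yaml",
#       arguments={"run_label": "baseline"},
#   )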