Files
argo/batch-inference.yaml
Billy D. f3d7da9008 refactor: rename part-of label from llm-workflows to ai-ml-pipelines
llm-workflows repo has been phased out. Update labels to reflect
the new ai-ml-pipelines naming convention.
2026-02-02 17:41:27 -05:00

329 lines
12 KiB
YAML

# Batch Inference Workflow
# Runs LLM inference on a batch of inputs
# Triggered via NATS: ai.pipeline.trigger with pipeline="batch-inference"
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: batch-inference
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: batch-inference
    app.kubernetes.io/part-of: ai-ml-pipelines
spec:
  entrypoint: batch-inference
  serviceAccountName: argo-workflow
  arguments:
    parameters:
      - name: input-url
        description: "URL to JSON file with inference requests"
      - name: output-url
        value: ""
        description: "URL to store results (S3 path)"
      - name: use-rag
        value: "true"
        description: "Whether to use RAG for context"
      - name: max-tokens
        value: "500"
        description: "Maximum tokens per response"
      - name: temperature
        value: "0.7"
        description: "LLM temperature"
  templates:
    # Entrypoint DAG: fetch inputs -> run inference -> (optionally) upload results.
    - name: batch-inference
      dag:
        tasks:
          - name: fetch-inputs
            template: fetch-input-data
            arguments:
              parameters:
                - name: input-url
                  value: "{{workflow.parameters.input-url}}"
          - name: run-inference
            template: inference
            dependencies: [fetch-inputs]
            arguments:
              parameters:
                - name: use-rag
                  value: "{{workflow.parameters.use-rag}}"
                - name: max-tokens
                  value: "{{workflow.parameters.max-tokens}}"
                - name: temperature
                  value: "{{workflow.parameters.temperature}}"
              artifacts:
                - name: inputs
                  from: "{{tasks.fetch-inputs.outputs.artifacts.inputs}}"
          - name: upload-results
            template: upload-output
            dependencies: [run-inference]
            # The templated value must be quoted INSIDE the expression: with an
            # empty output-url an unquoted substitution renders " != ''", which
            # govaluate cannot parse, erroring the task instead of skipping it.
            when: "'{{workflow.parameters.output-url}}' != ''"
            arguments:
              parameters:
                - name: output-url
                  value: "{{workflow.parameters.output-url}}"
              artifacts:
                - name: results
                  from: "{{tasks.run-inference.outputs.artifacts.results}}"

    # Downloads the batch input JSON from S3 or HTTP(S) and validates that it
    # contains a top-level "requests" array. Output artifact: /tmp/inputs.
    - name: fetch-input-data
      inputs:
        parameters:
          - name: input-url
      outputs:
        artifacts:
          - name: inputs
            path: /tmp/inputs
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import json
            import sys
            import urllib.request
            from pathlib import Path
            input_url = "{{inputs.parameters.input-url}}"
            output_dir = Path("/tmp/inputs")
            output_dir.mkdir(parents=True, exist_ok=True)
            print(f"Fetching inputs from: {input_url}")
            if input_url.startswith("s3://"):
                # boto3 is installed at runtime to keep the base image slim.
                import subprocess
                subprocess.run(["pip", "install", "boto3", "-q"], check=True)
                import boto3
                s3 = boto3.client("s3")
                bucket, key = input_url[5:].split("/", 1)
                s3.download_file(bucket, key, str(output_dir / "inputs.json"))
            elif input_url.startswith("http"):
                urllib.request.urlretrieve(input_url, output_dir / "inputs.json")
            else:
                print(f"Unsupported URL scheme: {input_url}")
                sys.exit(1)
            # Validate JSON structure before handing the artifact downstream.
            with open(output_dir / "inputs.json") as f:
                data = json.load(f)
            if "requests" not in data:
                print("Error: JSON must contain 'requests' array")
                sys.exit(1)
            print(f"Loaded {len(data['requests'])} inference requests")
        resources:
          requests:
            memory: 256Mi
            cpu: 100m

    # Runs the actual batch: optionally augments each query with RAG context
    # (embeddings -> Milvus search -> rerank), then calls vLLM chat completions.
    # RAG failures degrade gracefully to plain inference; per-request LLM errors
    # are recorded in the result entry instead of failing the batch.
    - name: inference
      inputs:
        parameters:
          - name: use-rag
          - name: max-tokens
          - name: temperature
        artifacts:
          - name: inputs
            path: /tmp/inputs
      outputs:
        artifacts:
          - name: results
            path: /tmp/results
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import subprocess
            subprocess.run(["pip", "install", "httpx", "pymilvus", "-q"], check=True)
            import json
            import httpx
            from pathlib import Path
            from typing import List, Dict
            # Configuration (in-cluster service endpoints)
            VLLM_URL = "http://llm-draft.ai-ml.svc.cluster.local:8000"
            EMBEDDINGS_URL = "http://embeddings-predictor.ai-ml.svc.cluster.local"
            RERANKER_URL = "http://reranker-predictor.ai-ml.svc.cluster.local"
            MILVUS_HOST = "milvus.ai-ml.svc.cluster.local"
            LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
            use_rag = "{{inputs.parameters.use-rag}}" == "true"
            max_tokens = int("{{inputs.parameters.max-tokens}}")
            temperature = float("{{inputs.parameters.temperature}}")
            input_dir = Path("/tmp/inputs")
            output_dir = Path("/tmp/results")
            output_dir.mkdir(parents=True, exist_ok=True)
            # Load inputs produced by the fetch-input-data template.
            with open(input_dir / "inputs.json") as f:
                data = json.load(f)
            requests = data["requests"]
            print(f"Processing {len(requests)} requests (RAG: {use_rag})")
            # Initialize Milvus if using RAG; fall back to plain inference on failure.
            collection = None
            if use_rag:
                try:
                    from pymilvus import connections, Collection, utility
                    connections.connect(host=MILVUS_HOST, port=19530)
                    if utility.has_collection("knowledge_base"):
                        collection = Collection("knowledge_base")
                        collection.load()
                        print("Milvus connected")
                except Exception as e:
                    print(f"Milvus connection failed: {e}")
                    use_rag = False
            def get_embeddings(texts: List[str], client: httpx.Client) -> List[List[float]]:
                response = client.post(
                    f"{EMBEDDINGS_URL}/embeddings",
                    json={"input": texts, "model": "bge"}
                )
                # Surface HTTP failures clearly instead of a confusing JSON error.
                response.raise_for_status()
                result = response.json()
                return [d["embedding"] for d in result.get("data", [])]
            def search_milvus(embedding: List[float]) -> List[Dict]:
                results = collection.search(
                    data=[embedding],
                    anns_field="embedding",
                    param={"metric_type": "COSINE", "params": {"ef": 64}},
                    limit=5,
                    output_fields=["text", "source"]
                )
                docs = []
                for hits in results:
                    for hit in hits:
                        docs.append({
                            "text": hit.entity.get("text", ""),
                            "source": hit.entity.get("source", ""),
                            "score": hit.score
                        })
                return docs
            def rerank(query: str, documents: List[str], client: httpx.Client) -> List[Dict]:
                response = client.post(
                    f"{RERANKER_URL}/v1/rerank",
                    json={"query": query, "documents": documents}
                )
                response.raise_for_status()
                return response.json().get("results", [])
            # Process requests
            results = []
            with httpx.Client(timeout=120.0) as client:
                for i, req in enumerate(requests):
                    query = req.get("text", req.get("query", ""))
                    req_id = req.get("id", str(i))
                    print(f"Processing {i+1}/{len(requests)}: {query[:50]}...")
                    context = ""
                    rag_sources = []
                    if use_rag and collection is not None:
                        try:
                            # Get embeddings, search the knowledge base, rerank,
                            # and keep the top 3 documents as prompt context.
                            embeddings = get_embeddings([query], client)
                            if embeddings:
                                docs = search_milvus(embeddings[0])
                                if docs:
                                    doc_texts = [d["text"] for d in docs]
                                    reranked = rerank(query, doc_texts, client)
                                    sorted_docs = sorted(reranked, key=lambda x: x.get("relevance_score", 0), reverse=True)[:3]
                                    context = "\n\n".join([doc_texts[d["index"]] for d in sorted_docs])
                                    rag_sources = [docs[d["index"]].get("source", "") for d in sorted_docs]
                        except Exception as e:
                            print(f"  RAG failed: {e}")
                    # Generate response; per-request failures are recorded, not fatal.
                    try:
                        messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
                        if context:
                            messages.append({"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"})
                        else:
                            messages.append({"role": "user", "content": query})
                        response = client.post(
                            f"{VLLM_URL}/v1/chat/completions",
                            json={
                                "model": LLM_MODEL,
                                "messages": messages,
                                "max_tokens": max_tokens,
                                "temperature": temperature
                            }
                        )
                        response.raise_for_status()
                        result = response.json()
                        answer = result["choices"][0]["message"]["content"]
                    except Exception as e:
                        answer = f"Error: {e}"
                    results.append({
                        "id": req_id,
                        "query": query,
                        "response": answer,
                        "used_rag": bool(context),
                        "rag_sources": rag_sources
                    })
            # Save results
            with open(output_dir / "results.json", "w") as f:
                json.dump({"results": results}, f, indent=2)
            print(f"Completed {len(results)} inferences")
            if collection is not None:
                from pymilvus import connections
                connections.disconnect("default")
        envFrom:
          - configMapRef:
              name: ai-services-config
        resources:
          requests:
            memory: 1Gi
            cpu: 500m

    # Uploads results.json to the given S3 path. Only runs when output-url is
    # non-empty (see the `when:` on the upload-results task above).
    - name: upload-output
      inputs:
        parameters:
          - name: output-url
        artifacts:
          - name: results
            path: /tmp/results
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import subprocess
            subprocess.run(["pip", "install", "boto3", "-q"], check=True)
            import sys
            import boto3
            from pathlib import Path
            output_url = "{{inputs.parameters.output-url}}"
            results_file = Path("/tmp/results/results.json")
            print(f"Uploading results to: {output_url}")
            if output_url.startswith("s3://"):
                s3 = boto3.client("s3")
                bucket, key = output_url[5:].split("/", 1)
                s3.upload_file(str(results_file), bucket, key)
                print("Upload complete")
            else:
                print(f"Unsupported URL scheme: {output_url}")
                sys.exit(1)
        resources:
          requests:
            memory: 256Mi
            cpu: 100m